# --------------------------------------------------------
# modified from Hora

import os
import torch
import numpy as np
from isaacgym import gymtorch
from isaacgym import gymapi
from isaacgym.torch_utils import to_torch, unscale, quat_apply, tensor_clamp, torch_rand_float, quat_conjugate, quat_mul, quat_from_angle_axis, scale, get_euler_xyz
from glob import glob
from hora.utils.misc import tprint
from .base.vec_task import VecTask
import math

import sys
sys.path.append("../IsaacGymEnvs2/isaacgymenvs")
import yaml
import argparse

import pytorch_kinematics as pk
import wandb
try:
    from torchvision.ops import box_convert
    from torchvision.transforms.functional import quaternion_from_matrix
except:
    pass
import torch.nn as nn
import json


def batched_index_select(values, indices, dim=1):
    """Gather entries of ``values`` along ``dim`` using batch-wise ``indices``.

    ``indices`` is padded with one trailing singleton axis for every axis of
    ``values`` that follows ``dim`` and is broadcast across those axes. When
    ``indices`` has more than ``dim + 1`` axes, the surplus axes (positions
    ``dim`` up to, but excluding, its last axis) are inserted into ``values``
    right before ``dim`` and broadcast to the sizes found in ``indices``.
    The result is ``values.gather`` along the correspondingly shifted axis.

    Example shapes with ``dim=1``:
        values ``(b, n, d)``, indices ``(b, k)``    -> ``(b, k, d)``
        values ``(b, n, d)``, indices ``(b, m, k)`` -> ``(b, m, k, d)``

    Args:
        values: Tensor to read from.
        indices: Index tensor holding positions along ``dim`` of ``values``;
            it is passed to ``Tensor.gather`` and must satisfy its dtype and
            size requirements.
        dim: Axis of ``values`` that is indexed. Negative values count from
            the last axis of ``values``.

    Returns:
        Tensor shaped like ``indices`` followed by the axes of ``values``
        that come after ``dim``.
    """
    # The slice arithmetic below only works for a non-negative axis, so map
    # e.g. dim=-1 to the concrete position first.
    if dim < 0:
        dim += len(values.shape)

    trailing_dims = values.shape[(dim + 1):]
    num_index_dims = len(indices.shape)

    # Broadcast the indices over every value axis that follows `dim`.
    indices = indices[(..., *((None,) * len(trailing_dims)))]
    indices = indices.expand(*((-1,) * num_index_dims), *trailing_dims)

    # Index axes beyond the shared leading ones get matching new axes in
    # `values`, inserted right before `dim`.
    extra_index_dims = num_index_dims - (dim + 1)
    values = values[(*((slice(None),) * dim), *((None,) * extra_index_dims), ...)]

    # Expand only the freshly inserted axes; -1 keeps all other sizes.
    expand_shape = [-1] * len(values.shape)
    expand_slice = slice(dim, dim + extra_index_dims)
    expand_shape[expand_slice] = indices.shape[expand_slice]
    values = values.expand(*expand_shape)

    # The indexed axis moved right by the number of inserted axes.
    return values.gather(dim + extra_index_dims, indices)



def dict2namespace(config):
    """Recursively turn a (nested) dict into an ``argparse.Namespace``.

    Every key of ``config`` becomes an attribute of the returned namespace.
    Values that are themselves dicts are converted the same way; all other
    values (including lists that contain dicts) are stored unchanged.
    """
    ns = argparse.Namespace()
    for name, entry in config.items():
        setattr(ns, name, dict2namespace(entry) if isinstance(entry, dict) else entry)
    return ns




class LeapHandHora(VecTask):
    def __init__(self, config, sim_device, graphics_device_id, headless):
        self.config = config
        # before calling init in VecTask, need to do
        # 1. setup randomization
        self._setup_domain_rand_config(config['env']['randomization'])
        # 2. setup privileged information
        self._setup_priv_option_config(config['env']['privInfo'])
        
        # 4. setup reward
        self._setup_reward_config(config['env']['reward'])
        
        # 3. setup object assets
        self._setup_object_info(config['env']['object'])
        
        self.base_obj_scale = config['env']['baseObjScale']
        self.save_init_pose = config['env']['genGrasps']
        self.aggregate_mode = self.config['env']['aggregateMode']
        self.up_axis = 'z'
        self.reset_z_threshold = self.config['env']['reset_height_threshold']
        self.grasp_cache_name = self.config['env']['grasp_cache_name']
        self.evaluate = self.config['on_evaluation']
        self.priv_info_dict = {
            'obj_position': (0, 3),
            'obj_scale': (3, 4),
            'obj_mass': (4, 5),
            'obj_friction': (5, 6),
            'obj_com': (6, 9),
        }
        
        self.evaluate_for_statistics = self.config['env'].get('evaluateForStatistics', False)
        
        self.evaluate_goal_conditioned = self.config['env'].get('evaluateGoalConditioned', False)
        self.train_goal_conditioned = self.config['env'].get('trainGoalConditioned', False)
        
        self.rot_axis = self.config['env'].get('rotAxis', 'z')
        self.rot_axis_mult = self.config['env'].get('rotAxisMult', -1)
        self.custm_rot_axis = self.config['env'].get('custmRotAxis', False)
        self.custm_rot_axis_idx = self.config['env'].get('custmRotAxisIdx', 6)
        custm_rot_axis_fn = "assets/so3_32_evenly_distributed_axes.json"
        with open(custm_rot_axis_fn, 'r') as f:
            self.custm_rot_axis_list = json.load(f)
        self.custm_rot_axis_tsr = self.custm_rot_axis_list[self.custm_rot_axis_idx]
        self.custm_rot_axis_tsr = torch.tensor(self.custm_rot_axis_tsr, dtype=torch.float32).cuda()
        
        
        self.hand_facing_dir = self.config['env'].get('handFacingDir', 'up')
        self.omni_wrist_ornt = self.config['env'].get('omniWristOrnt', False)
        self.omni_wrist_down_ornt_only = self.config['env'].get('omniWristDownOrntOnly', False)
        self.omni_wrist_horizontal_ornt_only = self.config['env'].get('omniWristHorizontalOrntOnly', False)
        self.omni_wrist_horizontal_ornt_rot_x = self.config['env'].get('omniWristHorizontalOrntRotX', False)
        self.omni_wrist_lower_half_only = self.config['env'].get('omniWristLowerHalfOnly', False)
        
        self.disable_obj_gravity = self.config['env'].get('disableObjGravity', False)
        
        self.gravity_val = self.config['env'].get('gravityVal', -9.81)
        
        self.add_fingertip_obs = self.config['env'].get('getFingertipObs', False)
        self.add_fingertip_ornt_obs = self.config['env'].get('addFingertipOrntObs', False)
        
        self.adjustable_rot_vel = self.config['env'].get('adjustableRotVel', False)
        self.add_fingertip_state_vel_obs = self.config['env'].get('addFingertipStateVelObs', False)
        self.add_object_state_obs = self.config['env'].get('addObjectStateObs', False)
        
        self.schedule_mass_upper_bound = self.config['env'].get('scheduleMassUperBound', False)
        self.schedule_mass_upper_min = self.config['env'].get('scheduleMassUperBoundMin', 0.05)
        self.schedule_mass_upper_max = self.config['env'].get('scheduleMassUperBoundMax', 0.05)
        self.mass_upper_bound_warming_up_steps = self.config['env'].get('scheduleMassUperBoundWarmingUpSteps', 100)
        self.mass_upper_bound_increasing_steps = self.config['env'].get('scheduleMassUperBoundIncreasingSteps', 200)
        self.schedule_mass_upper_bound_step = 0
        
        self.lag_history_len = 32
        self.lag_history_buf_length = self.config['env'].get('lagHistoryBufLength', 3)
        
        
        self.schedule_rot_vel_coef = self.config['env'].get('scheduleRotVelCoef', False)
        self.rot_vel_coef_min, self.rot_vel_coef_max = self.config['env'].get('rotVelCoefMin', 0.0001), self.config['env'].get('rotVelCoefMax', 1.0)
        self.schedule_rot_vel_warming_up_steps = self.config['env'].get('scheduleRotVelWarmingUpSteps', 200)
        self.schedule_rot_vel_increasing_steps = self.config['env'].get('scheduleRotVelIncreasingSteps', 1000)
        self.rot_vel_coef = self.rot_vel_coef_min # 
        self.rot_vel_coef_step = 0 # 
        # 
        self.sim_gravity_via_force = self.config['env'].get('simGravityViaForce', False)
        self.schedule_gravity_force = self.config['env'].get('scheduleGravityForce', False)
        self.gravity_force_min, self.gravity_force_max = self.config['env'].get('gravityForceMin', 1.0), self.config['env'].get('gravityForceMax', 9.8)
        self.cur_gravity_force = self.gravity_force_min
        self.schedule_gravity_force_step = 0
        self.schedule_gravity_force_warming_up_steps = self.config['env'].get('scheduleGravityForceWarmingUpSteps', 150)
        self.schedule_gravity_force_increasing_steps = self.config['env'].get('scheduleGravityForceIncreasingSteps', 200)
        
        self.object_downfacing_init_z = self.config['env'].get('objectDownfacingInitZ', 0.375)
        
        self.hand_type = self.config['env'].get('handType', 'leap')
        self.additional_tag = self.config['env'].get('additionalTag', '')
        
        if self.hand_type == 'leap':
            self.fingertips = ["thumb_tip_head", "index_tip_head", "middle_tip_head", "ring_tip_head"]
        elif self.hand_type in ['allegro_public', 'allegro_internal']:
            self.fingertips = ["link_15.0_tip", "link_3.0_tip", "link_7.0_tip", "link_11.0_tip"]
        else:
            raise ValueError(f"hand_type {self.hand_type} not supported")
        
        
        self.add_force_obs = self.config['env'].get('addForceObs', False)
        self.upper_mass_limit = self.config['env']['randomization']['randomizeMassUpper']
        self.add_contact_force_obs = self.config['env'].get('addContactForceObs', False)
        self.add_contact_force_with_binary_contacts = self.config['env'].get('addContactForceWithBinaryContacts', False)
        self.add_obj_goal_observations = self.config['env'].get('addObjGoalObservations', False)
        
        
        self.grasp_to_grasp = self.config['env'].get('graspToGrasp', False)
        self.hand_tracking = self.config['env'].get('handTracking', False)
        self.hand_tracking_target_upd_steps = self.config['env'].get('handTrackingTargetUpdSteps', 50)
        
        
        # 
        self.hand_tracking_nobj = self.config['env'].get('handTrackingNObj', False)
        self.apply_obj_virtual_force = self.config['env'].get('applyObjVirtualForce', False) 
        self.real_to_sim_auto_tune = self.config['env'].get('realToSimAutoTune', False)
        self.real_to_sim_auto_tune_w_obj = self.config['env'].get('realToSimAutoTuneWObj', False)  
        self.preset_pd_gains = self.config['env'].get('presetPDGains', False)
        self.preset_pd_gains_fn = self.config['env'].get('presetPDGainsFn', None)
        
        
        self.add_disturbances_to_init_state = self.config['env'].get('addDisturbancesToInitState', False)
        self.recovery_training = self.config['env'].get('recoveryTraining', False)
        self.recovery_succ_obj_pos_thres = self.config['env'].get('recoverySuccObjPosThres', 0.001)
        
        self.random_start = self.config['env'].get('randomStart', False)
        
        
        ##### Action compensator settings #####
        self.train_action_compensator = self.config['env'].get('trainActionCompensator', False) 
        self.real_play_infos_fn = self.config['env'].get('realPlayInfosFn', None) 
        self.action_compensator_not_using_real_actions = self.config['env'].get('actionCompensatorNotUsingRealActions', False) 
        self.train_action_compensator_uan = self.config['env'].get('trainActionCompensatorUAN', False)
        self.per_joint_action_compensator = self.config['env'].get('perJointActionCompensator', False)
        self.fingertip_only_action_compensator = self.config['env'].get('fingertipOnlyActionCompensator', False)

        self.singe_joint_replay_uan = self.config['env'].get('singleJointReplayUAN', False)
        
        
        self.action_compensator_not_use_history = self.config['env'].get('actionCompensatorNotUseHistory', False)
        self.action_compensator_w_obj = self.config['env'].get('actionCompensatorWObj', False) 
        self.compensator_w_obj_rew_type = self.config['env'].get('compensatorWObjRewType', 'angvel') 
        self.train_action_compensator_w_obj_motion_pred = self.config['env'].get('trainActionCompensatorWMotionPred', False) 
        self.train_action_compensator_w_obj_motion_pred_model_fn = self.config['env'].get('trainActionCompensatorWMotionPredModelFn', None) 
        
        
        
        
        
        self.train_action_compensator_free_hand = self.config['env'].get('trainActionCompensatorFreeHand', False) 
        self.train_action_compensator_w_real_wm = self.config['env'].get('trainActionCompensatorWRealWM', False) 
        self.action_compensator_input_joint_idx = self.config['env'].get('actionCompensatorInputJointIdx', -1)
        self.action_compensator_input_finger_idx = self.config['env'].get('actionCompensatorInputFingerIdx', -1)
        self.action_compensator_output_joint_idx = self.config['env'].get('actionCompensatorOutputJointIdx', -1)
        self.action_compensator_output_finger_idx = self.config['env'].get('actionCompensatorOutputFingerIdx', -1)
        self.train_action_compensator_w_finger_rew = self.config['env'].get('trainActionCompensatorWFingerRew', False)
        self.action_compensator_compute_finger_rew = self.config['env'].get('actionCompensatorComputeFingerRew', False)
        self.action_compensator_add_invaction_rew = self.config['env'].get('actionCompensatorAddInvActionRew', False)
        self.action_compensator_invaction_ckpt_fn = self.config['env'].get('actionCompensatorInvActionCkptFn', '')
        self.action_compensator_w_full_hand = self.config['env'].get('actionCompensatorWFullHand', False)
        self.train_action_compensator_w_real_wm_multi_compensator = self.config['env'].get('trainActionCompensatorWRealWMMultiCompensator', False)
        self.num_action_compensator = self.config['env'].get('numActionCompensator', 1)
        
        self.hierarchical_compensator = self.config['env'].get('hierarchicalCompensator', False)
        self.train_action_compensator_w_residual_wm = self.config['env'].get('trainActionCompensatorWResidualWM', False) 
        self.use_bc_base_policy = self.config['env'].get('useBCBasePolicy', False)
        
        self.wm_per_joint_compensator_full_hand = self.config['env'].get('wmPerJoitCompensatorFullHand', False)
        self.enable_vhacd = self.config['env'].get('enableVhacd', False)
        self.add_obj_features = self.config['env'].get('addObjFeatures', False)
        self.obj_feature_dim = self.config['env'].get('objFeatureDim', 64)
        
        self.specified_wrist_ornt = self.config['env'].get('specifiedWristOrnt', '')
        
        
        if self.action_compensator_input_joint_idx >= 0:
            self.compensator_input_joint_idxes = [self.action_compensator_input_joint_idx]
        elif self.action_compensator_input_finger_idx >= 0:
            self.compensator_input_joint_idxes = [self.action_compensator_input_finger_idx * 4 + i for i in range(4)]
        else:
            self.compensator_input_joint_idxes = [_ for _ in range(16)]
        self.compensator_input_joint_idxes = torch.tensor(self.compensator_input_joint_idxes).long().cuda()
        
        if self.action_compensator_output_joint_idx >= 0:
            self.compensator_output_joint_idxes = [self.action_compensator_output_joint_idx]
        elif self.action_compensator_output_finger_idx >= 0:
            self.compensator_output_joint_idxes = [self.action_compensator_output_finger_idx * 4 + i for i in range(4)]
        else:
            self.compensator_output_joint_idxes = [_ for _ in range(16)]
        self.compensator_output_joint_idxes = torch.tensor(self.compensator_output_joint_idxes).long().cuda()
        self.use_masked_action_compensator = self.config['env'].get('useMaskedActionCompensator', False)
        
        self.tune_bc_model = self.config['env'].get('tuneBCModel', False)
        self.bc_model_history_length = self.config['env'].get('bcModelHistoryLength', 10)
        self.tune_bc_via_compensator_model = self.config['env'].get('tuneBCviaCompensatorModel', False)
        
        self.evaluate_action_add_noise = self.config['env'].get('evaluateActionAddNoise', False)
        self.evaluate_action_noise_std = self.config['env'].get('evaluateActionNoiseStd', 1./24)
        
        
        if self.tune_bc_via_compensator_model:
            self.bc_model_actions = torch.zeros((self.config['env']['numEnvs'], 16), dtype=torch.float32, device='cuda')
        
        if self.train_action_compensator_w_real_wm:
            if self.action_compensator_w_full_hand:
                self._init_compensator_real_world_model_full_hand()
            elif self.action_compensator_input_finger_idx == -1:
                if self.train_action_compensator_w_real_wm_multi_compensator:
                    self._init_multi_compensator_real_world_model_perjoint()
                else:
                    self._init_compensator_real_world_model_perjoint()
            else:
                self._init_compensator_real_world_model()
            
            self.compensator_reset_nn = 0
            
            if self.action_compensator_w_full_hand or self.wm_per_joint_compensator_full_hand:
                if self.train_action_compensator_w_real_wm_multi_compensator:
                    self._init_multi_delta_action_model_full_hand()
                else:
                    self._init_delta_action_model_full_hand()
            else:
                self._init_delta_action_model()
            
            self.compensating_targets = torch.zeros((self.config['env']['numEnvs'], self.compensator_output_joint_idxes.shape[0]), dtype=torch.float32, device='cuda')
            
            if self.train_action_compensator_w_finger_rew or self.action_compensator_compute_finger_rew:
                self.build_pk_chain_finger()
        
            self.compensator_output_joint_idxes = self.sorted_figner_joint_idxes.clone()
            
            self.joint_idx_to_wm_pred_delta_abs = {}
            self.joint_idx_to_delta_action = {}
        
        
        if self.action_compensator_add_invaction_rew:
            self._init_and_load_inverse_dynamics_model() 
        
        
        if self.train_action_compensator:
            self.real_play_infos = np.load(self.real_play_infos_fn, allow_pickle=True).item() 
            
            if 'qpos' in self.real_play_infos:
                real_replay_qpos = self.real_play_infos['qpos']
                real_replay_qtars = self.real_play_infos['qtars']
                real_replay_init_obj_states = self.real_play_infos['init_obj_pose']
            elif 'states' in self.real_play_infos:
                real_replay_qpos = self.real_play_infos['states']
                real_replay_qtars = self.real_play_infos['actions']
                real_replay_init_obj_states = torch.tensor(
                    [0, 0, 0, 0, 0, 0, 1], dtype=torch.float32
                ).unsqueeze(0).repeat(real_replay_qpos.shape[0], 1).contiguous() 
                real_replay_init_obj_states = real_replay_init_obj_states.detach().cpu().numpy()
            else:
                real_replay_qpos = []
                real_replay_qtars = []
                maxx_len = 401
                for traj_idx in self.real_play_infos:
                    cur_traj_idx_to_real_replay_info = self.real_play_infos[traj_idx]
                    cur_traj_idx_to_real_replay_info = np.load(cur_traj_idx_to_real_replay_info, allow_pickle=True).item()
                    cur_real_qpos, cur_real_qtars = cur_traj_idx_to_real_replay_info['qpos'][:maxx_len], cur_traj_idx_to_real_replay_info['qtars'][:maxx_len]
                    # print(f"cur_real_qpos: {cur_real_qpos.shape }, cur_real_qtars: {cur_real_qtars.shape}")
                    real_replay_qpos.append(cur_real_qpos)
                    real_replay_qtars.append(cur_real_qtars)
                real_replay_qpos = np.stack(real_replay_qpos, axis=0)
                real_replay_qtars = np.stack(real_replay_qtars, axis=0) # nn_trajs x nn_ts x nn_dofs #
            self.real_replay_qpos = torch.from_numpy(real_replay_qpos).float().cuda() # nn_trajs x nn_ts x nn_dofs #
            self.real_replay_qtars = torch.from_numpy(real_replay_qtars).float().cuda()  # nn_trajs x nn_ts x nn_dofs #
           
            print(f"real_replay_qpos: {self.real_replay_qpos.size()}, real_replay_qtars: {self.real_replay_qtars.size()}")
            self.envs_replay_qpos = torch.zeros((self.config['env']['numEnvs'], self.real_replay_qpos.size(1), self.real_replay_qpos.size(2)), dtype=torch.float32).cuda()  # nn_envs x nn_ts x nn_dofs #
            self.envs_replay_qtars = torch.zeros((self.config['env']['numEnvs'], self.real_replay_qtars.size(1), self.real_replay_qtars.size(2)), dtype=torch.float32).cuda()  # nn_envs x nn_ts x nn_dofs #
            self.real_replay_init_obj_states = torch.from_numpy(real_replay_init_obj_states).float().cuda()
            if self.train_action_compensator_w_obj_motion_pred:
                self._init_and_load_obj_motion_pred_model()
            
            
        self.delta_actions = None
        
        ##### Openloop replay settings #####
        self.openloop_replay = self.config['env'].get('openloopReplay', False)
        self.openloop_replay_folder = self.config['env'].get('openloopReplayFolder', None)
        self.replay_finger_idx = self.config['env'].get('replayFingerIdx', -1)
        self.replay_joint_idx = self.config['env'].get('replayJointIdx', -1)
        
        if self.openloop_replay:
            if os.path.isdir(self.openloop_replay_folder):
                self.openloop_replay_src_actions = []
                self.openloop_replay_src_states = []
                nn_replay_experiences = self.config['env']['numEnvs']
                for i_env in range(nn_replay_experiences): # openloop replay folder #
                    cur_env_sv_fn = os.path.join(self.openloop_replay_folder, f"env_{i_env}.npy")
                    cur_env_dict = np.load(cur_env_sv_fn, allow_pickle=True).item()
                    cur_env_qpos, cur_env_qtars = cur_env_dict['shadow_hand_dof_pos'], cur_env_dict['shadow_hand_dof_tars']
                    self.openloop_replay_src_states.append(cur_env_qpos)
                    self.openloop_replay_src_actions.append(cur_env_qtars)
                self.openloop_replay_src_states = torch.from_numpy(np.stack(self.openloop_replay_src_states, axis=0)).float().cuda() # nn_envs x nn_ts x nn_dofs #
                self.openloop_replay_src_actions = torch.from_numpy(np.stack(self.openloop_replay_src_actions, axis=0)).float().cuda() # nn_envs x nn_ts x nn_dofs #
            else:
                replay_info_dict = np.load(self.openloop_replay_folder, allow_pickle=True).item()
                test_replay_states = replay_info_dict['shadow_hand_dof_pos']
                test_replay_actions = replay_info_dict['shadow_hand_dof_tars']
                self.openloop_replay_src_states = torch.from_numpy(test_replay_states).float().cuda()
                self.openloop_replay_src_actions = torch.from_numpy(test_replay_actions).float().cuda()
            
            if self.replay_joint_idx >= 0:
                masked_out_joint_idxes = [ i for i in range(self.openloop_replay_src_actions.size(-1)) if i != self.replay_joint_idx ]
            elif self.replay_finger_idx >= 0:
                masked_out_joint_idxes = [ i for i in range(self.openloop_replay_src_actions.size(-1)) if (i < self.replay_finger_idx * 4 and i >= self.replay_finger_idx * 4 + 4) ]
            else:
                masked_out_joint_idxes = None
            
            if masked_out_joint_idxes is not None:
                masked_out_joint_idxes = torch.tensor(masked_out_joint_idxes, dtype=torch.int64, device=self.openloop_replay_src_actions.device)    
                self.openloop_replay_src_states[..., masked_out_joint_idxes] = 0.0
                self.openloop_replay_src_actions[..., masked_out_joint_idxes] = 0.0
        ##### Openloop replay settings #####
        
        
        if self.preset_pd_gains:
            # preset the pd gains -- objtype idx to preset  
            print(f"INFO: Setting the preset PD gains...")
            tot_pd_gains_fn_list = self.preset_pd_gains_fn.split('ANDOBJ')
            self.policy_idx_to_pd_pgains = {}
            self.policy_idx_to_pd_dgains = {}
            self.policy_idx_to_rigid_body_masses = {}
            self.policy_idx_to_rigid_body_inertias = {}
            num_dofs = 16
            for policy_idx, cur_pd_gains_fn in enumerate(tot_pd_gains_fn_list):
                if cur_pd_gains_fn == 'NONE':
                    self.policy_idx_to_pd_pgains[policy_idx] = torch.ones(( num_dofs), dtype=torch.float).cuda() * float(self.config['env']['controller']['pgain'])
                    self.policy_idx_to_pd_dgains[policy_idx] = torch.ones(( num_dofs), dtype=torch.float).cuda() * float(self.config['env']['controller']['dgain'])
                    self.policy_idx_to_rigid_body_masses[policy_idx] = None
                    self.policy_idx_to_rigid_body_inertias[policy_idx] = None
                else:
                    cur_pd_gains_folder = np.load(cur_pd_gains_fn, allow_pickle=True).item()
                    cur_pgains = cur_pd_gains_folder['pgains']
                    cur_dgains = cur_pd_gains_folder['dgains']
                    cur_pgains = torch.from_numpy(cur_pgains).float().cuda()
                    cur_dgains = torch.from_numpy(cur_dgains).float().cuda()
                    
                    self.policy_idx_to_pd_pgains[policy_idx] = cur_pgains
                    self.policy_idx_to_pd_dgains[policy_idx] = cur_dgains
                    
                    if 'rigid_body_masses' in cur_pd_gains_folder:
                        self.policy_idx_to_rigid_body_masses[policy_idx] = torch.from_numpy(cur_pd_gains_folder['rigid_body_masses']).float().cuda()
                    else:
                        self.policy_idx_to_rigid_body_masses[policy_idx] = None
                    if 'rigid_body_inertias' in cur_pd_gains_folder:
                        self.policy_idx_to_rigid_body_inertias[policy_idx] = torch.from_numpy(cur_pd_gains_folder['rigid_body_inertias']).float().cuda()
                    else:
                        self.policy_idx_to_rigid_body_inertias[policy_idx] = None
            
            self.preset_identified_friction_coef = None
            self.preset_identified_friction_coef = 2.16704494
        
        
        
        def get_full_grasp_cache_name(grasp_cache_prefix, inst_idx, obj_scale):
            grasp_cache_fn = f'cache/{grasp_cache_prefix}_{inst_idx}_grasp_50k_s{str(obj_scale).replace(".", "")}.npy'
                
            return grasp_cache_fn
        
        
        ## randomize scale ##
        self.use_multi_objs = self.config['env'].get('useMultiObjs', False)
        self.seperate_inst_grasp_pose = self.config['env']['object']['seperateInstGraspPose']
        
        if self.use_multi_objs:
            self.saved_grasping_states = {}
            multi_obj_specifiedObjIdx = str(self.config['env']['object']['specifiedObjectIdx']).split('ANDOBJ')
            multi_obj_grasp_cache_name_list = self.grasp_cache_name.split('ANDOBJ')
            
            multi_obj_randomize_scale_str = str(self.config['env']['multiObjRandomizeScale']).split('ANDOBJ')
            self.obj_inst_idx_to_scale_list = {}
            
            self.randomize_scale_list = {}
            
            for i_obj_type, cur_obj_type_specifiedObjIdx in enumerate(multi_obj_specifiedObjIdx):
                cur_obj_type_randomize_scale_str = multi_obj_randomize_scale_str[i_obj_type]
                cur_obj_type_randomize_scale_list = cur_obj_type_randomize_scale_str.split('AND')
                cur_obj_type_randomize_scale_list = [ float(cur_scale_str) for cur_scale_str in cur_obj_type_randomize_scale_list ]
                
                cur_obj_inst_idxes = cur_obj_type_specifiedObjIdx.split('AND')
                cur_obj_tot_idxes = [ int(cur_idxxx) for cur_idxxx in cur_obj_inst_idxes ]
                cur_obj_grasp_cache_name = multi_obj_grasp_cache_name_list[i_obj_type]
                cur_offset_obj_idx = len(self.saved_grasping_states)
                
                for ii_inst, i_inst in enumerate(cur_obj_tot_idxes):
                
                    self.saved_grasping_states[i_inst + cur_offset_obj_idx] = {}
                    self.obj_inst_idx_to_scale_list[i_inst + cur_offset_obj_idx] = cur_obj_type_randomize_scale_list
                    
                    for s in cur_obj_type_randomize_scale_list:
                        self.randomize_scale_list[s] = 1
                        
                        cur_obj_cur_inst_grasping_states = []
                        cur_obj_cur_inst_cache_names = cur_obj_grasp_cache_name.split('AND')
                        for cur_grasp_cache_name in cur_obj_cur_inst_cache_names:
                            cur_grasp_cache_fn = get_full_grasp_cache_name(cur_grasp_cache_name, i_inst, s)
                            # print(f"cur_grasp_cache_fn: {cur_grasp_cache_fn}")
                            if not os.path.exists(cur_grasp_cache_fn):
                                continue
                            cur_obj_cur_inst_grasping_states.append(torch.from_numpy(np.load(cur_grasp_cache_fn)).float().cuda())
                        cur_obj_cur_inst_grasping_states = torch.cat(cur_obj_cur_inst_grasping_states, dim=0)
                        self.saved_grasping_states[i_inst + cur_offset_obj_idx][str(s)] = cur_obj_cur_inst_grasping_states
            self.randomize_scale_list = list(self.randomize_scale_list.keys())
            self.randomize_scale_list = [ float(cur_scale_str) for cur_scale_str in self.randomize_scale_list ]
        else:
            if self.randomize_scale and self.scale_list_init:
                if self.seperate_inst_grasp_pose:
                    print(f"Loading from seperate inst grasp poses...")
                    self.saved_grasping_states = {}
                    tot_inst_idxes = str(self.config['env']['object']['specifiedObjectIdx'])
                    tot_inst_idxes = tot_inst_idxes.split('AND'); tot_inst_idxes = [int(cur_idxxx) for cur_idxxx in tot_inst_idxes]
                    # for i_inst in range(self.nn_object_insts):
                    for i_inst in tot_inst_idxes:
                        self.saved_grasping_states[i_inst] = {}
                        
                        for s in self.randomize_scale_list: 
                            
                            if 'AND' not in self.grasp_cache_name:
                                grasp_cache_fn = get_full_grasp_cache_name(self.grasp_cache_name, i_inst, s)
                                if not os.path.exists(grasp_cache_fn):
                                    continue
                                self.saved_grasping_states[i_inst][str(s)] = torch.from_numpy(np.load(
                                    grasp_cache_fn
                                )).float().cuda()
                            else:
                                tot_grasping_states = []
                                tot_grasp_cache_names = self.grasp_cache_name.split('AND')
                                for cur_grasp_cache_name in tot_grasp_cache_names:
                                    cur_grasp_cache_fn = get_full_grasp_cache_name(cur_grasp_cache_name, i_inst, s)
                                    print(f"cur_grasp_cache_fn: {cur_grasp_cache_fn}")
                                    if not os.path.exists(  cur_grasp_cache_fn ):
                                        continue
                                    tot_grasping_states.append(
                                        torch.from_numpy(np.load(cur_grasp_cache_fn)).float().cuda()
                                    )
                                tot_grasping_states = torch.cat(tot_grasping_states, dim=0)
                                self.saved_grasping_states[i_inst][str(s)] = tot_grasping_states
                                    
                else:
                    self.saved_grasping_states = {}
                    for s in self.randomize_scale_list:
                        grasp_cache_fn = f'cache/{self.grasp_cache_name}_grasp_50k_s{str(s).replace(".", "")}.npy'
                        if not os.path.exists(grasp_cache_fn):
                            grasp_cache_fn = f'cache/{self.grasp_cache_name}_grasp_1k_s{str(s).replace(".", "")}.npy'
                        self.saved_grasping_states[str(s)] = torch.from_numpy(np.load(
                            grasp_cache_fn,
                        )).float().cuda()
            else:
                assert self.save_init_pose


        if self.add_obj_features:
            self.envs_obj_features = torch.zeros((self.config['env']['numEnvs'], self.obj_feature_dim), dtype=torch.float32).cuda()
        

        self.tot_reset_nn = 0

        super().__init__(config, sim_device, graphics_device_id, headless)

        self.debug_viz = self.config['env']['enableDebugVis']
        self.max_episode_length = self.config['env']['episodeLength']
        self.dt = self.sim_params.dt

        if self.viewer:
            cam_pos = gymapi.Vec3(0.0, 0.4, 1.5)
            cam_target = gymapi.Vec3(0.0, 0.0, 0.5)
            self.gym.viewer_camera_look_at(self.viewer, None, cam_pos, cam_target)

        # get gym GPU state tensors
        actor_root_state_tensor = self.gym.acquire_actor_root_state_tensor(self.sim)
        dof_state_tensor = self.gym.acquire_dof_state_tensor(self.sim)
        rigid_body_tensor = self.gym.acquire_rigid_body_state_tensor(self.sim)
        net_contact_forces = self.gym.acquire_net_contact_force_tensor(self.sim)

        self.allegro_hand_default_dof_pos = torch.zeros(self.num_allegro_hand_dofs, dtype=torch.float, device=self.device)
        self.dof_state = gymtorch.wrap_tensor(dof_state_tensor)
        self.contact_forces = gymtorch.wrap_tensor(net_contact_forces).view(self.num_envs, -1, 3)
        print(f"Contact Tensr Dimension: {self.contact_forces.shape}")
        self.allegro_hand_dof_state = self.dof_state.view(self.num_envs, -1, 2)[:, :self.num_allegro_hand_dofs]
        self.allegro_hand_dof_pos = self.allegro_hand_dof_state[..., 0]
        self.allegro_hand_dof_vel = self.allegro_hand_dof_state[..., 1]
        

        self.rigid_body_states = gymtorch.wrap_tensor(rigid_body_tensor).view(self.num_envs, -1, 13)
        self.num_bodies = self.rigid_body_states.shape[1]

        self.root_state_tensor = gymtorch.wrap_tensor(actor_root_state_tensor).view(-1, 13)
        
        self.init_root_state_tensor = torch.zeros_like(self.root_state_tensor)
        self.reset_pos_dist_val = 0.04
        
        self.target_root_state_tensor = torch.zeros_like(self.root_state_tensor)
        
        self._refresh_gym()

        if self.hand_tracking:
            self.num_dofs = (self.gym.get_sim_dof_count(self.sim) ) // self.num_envs
        else:
            self.num_dofs = self.gym.get_sim_dof_count(self.sim) // self.num_envs
        

        self.prev_targets = torch.zeros((self.num_envs, self.num_dofs), dtype=torch.float, device=self.device)
        self.cur_targets = torch.zeros((self.num_envs, self.num_dofs), dtype=torch.float, device=self.device)
        self.force_scale = self.config['env'].get('forceScale', 0.0)
        self.random_force_prob_scalar = self.config['env'].get('randomForceProbScalar', 0.0)
        self.force_decay = self.config['env'].get('forceDecay', 0.99)
        self.force_decay_interval = self.config['env'].get('forceDecayInterval', 0.08)
        self.force_decay = to_torch(self.force_decay, dtype=torch.float, device=self.device)
        self.rb_forces = torch.zeros((self.num_envs, self.num_bodies, 3), dtype=torch.float, device=self.device)
        
        
        self.seperate_inst_grasp_pose = self.config['env']['object']['seperateInstGraspPose']
        self.nn_object_insts = self.config['env']['object']['nnInsts']
        
        
        self.change_rot_dir = self.config['env'].get('changeRotDir', False)
        self.change_target_rot_dir = self.config['env'].get('changeTargetRotDir', 'z')
        self.change_rot_dir_period = self.config['env'].get('changeRotDirPeriod', 200)
        self.change_target_rot_dir_mult = self.config['env'].get('changeTargetRotDirMult', 1)
        
        self.randomize_rot_dir = self.config['env'].get('randomizeRotDir', False)
        # self.randomize_rot_dir_period = self.config['env'].get('randomizeRotDirPeriod', 200)
        self.use_preset_rot_dir = self.config['env'].get('usePresetRotDir', False)
        self.preset_rot_dir = torch.tensor([ -1, 0, 0 ], dtype=torch.float32, device=self.device) # tensor of the preset rot dir #
        
        
        self.add_translation = self.config['env'].get('addTranslation', False ) # whether to add the tranlation target
        self.translation_dir = self.config['env'].get('translationDir', 'z')
        self.translation_dir_mult = self.config['env'].get('translationDirMult', 1) # translation direction multiplier
        
        self.finetune_with_action_compensator = self.config['env'].get('finetuneWithActionCompensator', False) # whether to finetune the action compensator 
        self.delta_action_scale = self.config['env'].get('deltaActionScale', 1./48)
        
        if self.finetune_with_action_compensator:
            self._init_and_load_action_compesator()
        
        #### Construct the translation axis buffer ####
        self.trans_dir_buf = torch.zeros((self.num_envs, 3), dtype=torch.float32, device=self.device) # translational direction buffer
        if self.translation_dir == 'x':
            self.trans_dir_buf[..., 0] = self.translation_dir_mult
        elif self.translation_dir == 'y':
            self.trans_dir_buf[..., 1] = self.translation_dir_mult
        elif self.translation_dir == 'z':
            self.trans_dir_buf[..., 2] = self.translation_dir_mult
        #### Construct the translation axis buffer ####
        
        

        self.rot_axis_buf = torch.zeros((self.num_envs, 3), device=self.device, dtype=torch.float)
        
        # Useful buffers (0, 0.1, 0), (0, -0.1, 0)
        self.object_rot_prev = self.object_rot.clone()
        self.object_pos_prev = self.object_pos.clone()
        self.init_pose_buf = torch.zeros((self.num_envs, self.num_dofs), device=self.device, dtype=torch.float)
        self.actions = torch.zeros((self.num_envs, self.num_actions), device=self.device, dtype=torch.float)
        self.torques = torch.zeros((self.num_envs, self.num_actions), device=self.device, dtype=torch.float)
        self.dof_vel_finite_diff = torch.zeros((self.num_envs, self.num_dofs), device=self.device, dtype=torch.float)
        assert type(self.p_gain) in [int, float] and type(self.d_gain) in [int, float], 'assume p_gain and d_gain are only scalars'
        self.p_gain = torch.ones((self.num_envs, self.num_dofs), device=self.device, dtype=torch.float) * self.p_gain
        self.d_gain = torch.ones((self.num_envs, self.num_dofs), device=self.device, dtype=torch.float) * self.d_gain
        
        self.first_level_compensated_targets = torch.zeros((self.num_envs, self.num_dofs), device=self.device, dtype=torch.float)
        
        if self.preset_pd_gains:
            for policy_idx in self.policy_idx_to_env_list:
                cur_env_idxes = self.policy_idx_to_env_list[policy_idx]
                cur_pgains = self.policy_idx_to_pd_pgains[policy_idx]
                cur_dgains = self.policy_idx_to_pd_dgains[policy_idx]
                self.p_gain[cur_env_idxes, :] = cur_pgains.unsqueeze(0).repeat(cur_env_idxes.size(0), 1).contiguous()
                self.d_gain[cur_env_idxes, :] = cur_dgains.unsqueeze(0).repeat(cur_env_idxes.size(0), 1).contiguous()
            
            self.preset_pgains = self.p_gain.clone()
            self.preset_dgains = self.d_gain.clone()
            
        if self.real_to_sim_auto_tune:
            maxx_nn  = 100
            
            if self.real_to_sim_auto_tune_w_obj:
                self.auto_tune_sequences_fn = "cache/auto_tune_seq_dict_cuboidthin_n110_winitobjpose.npy"
                self.auto_tune_sequences = np.load(self.auto_tune_sequences_fn, allow_pickle=True).item()
                self.auto_tune_actions = self.auto_tune_sequences['actions'][:maxx_nn]
                self.auto_tune_states = self.auto_tune_sequences['states'][:maxx_nn]
                self.auto_tune_init_obj_poses = self.auto_tune_sequences['init_obj_poses'][:maxx_nn]
                self.auto_tune_states = torch.from_numpy(self.auto_tune_states).float().to(self.device) 
                self.auto_tune_actions = torch.from_numpy(self.auto_tune_actions).float().to(self.device)
                self.auto_tune_init_obj_poses = torch.from_numpy(self.auto_tune_init_obj_poses).float().to(self.device)
            else:
                self.auto_tune_sequences_fn = "cache/auto_tune_seq_dict_mujoco_cuboidthin_n1000.npy"
                self.auto_tune_sequences = np.load(self.auto_tune_sequences_fn, allow_pickle=True).item()
                self.auto_tune_actions = self.auto_tune_sequences['actions'][:maxx_nn]
                self.auto_tune_states = self.auto_tune_sequences['states'][:maxx_nn]
                self.auto_tune_states = torch.from_numpy(self.auto_tune_states).float().to(self.device) 
                self.auto_tune_actions = torch.from_numpy(self.auto_tune_actions).float().to(self.device)
            
            
            self.testing_traj_idx = 0
            self.testing_traj_ts = 0
            self.tested_envs_states = torch.zeros((self.num_envs, self.auto_tune_states.size(0), self.auto_tune_states.size(1), self.auto_tune_states.size(2)), dtype=torch.float32).to(self.device) 
            self.tested_envs_obj_states = torch.zeros((self.num_envs, self.auto_tune_states.size(0), self.auto_tune_states.size(1), 7), dtype=torch.float32).to(self.device) 
            
        
        
        self.rew_buf_aux_pose_guidance = torch.zeros((self.num_envs, ), device=self.device, dtype=torch.float)
        self.rew_buf_aux_pose_guidance_bonus = torch.zeros((self.num_envs, ), device=self.device, dtype=torch.float)
        self.rew_buf_wo_aux = torch.zeros((self.num_envs, ), device=self.device, dtype=torch.float)
        
        self.rotr_buf = torch.zeros((self.num_envs, ), device=self.device, dtype=torch.float)
        self.rotp_buf = torch.zeros((self.num_envs, ), device=self.device, dtype=torch.float)

        # Unit vector buffers # 
        self.x_unit_tensor = to_torch([1, 0, 0], dtype=torch.float, device=self.device).repeat((self.num_envs, 1))
        self.y_unit_tensor = to_torch([0, 1, 0], dtype=torch.float, device=self.device).repeat((self.num_envs, 1))
        self.z_unit_tensor = to_torch([0, 0, 1], dtype=torch.float, device=self.device).repeat((self.num_envs, 1))
        self.rnd_rot_tensor = to_torch([0, 0, 0, 1],    dtype=torch.float, device=self.device).repeat((self.num_envs, 1))
        self.reset_z_threshold_tensor = torch.ones((self.num_envs,), device=self.device, dtype=torch.float32) * self.reset_z_threshold
        
        self.lagging_obs_length = 30
        self.obs_buf_lag_history_qpos = torch.zeros((self.num_envs, self.lagging_obs_length, 16), device=self.device, dtype=torch.float32)
        self.obs_buf_lag_history_qtars = torch.zeros((self.num_envs, self.lagging_obs_length, 16), device=self.device, dtype=torch.float32) # nn envs x hist len x 16 #
        self.obs_buf_lag_history_compensated_qtars = torch.zeros((self.num_envs, self.lagging_obs_length, 16), device=self.device, dtype=torch.float32)
        self.cur_fullhand_compensated_target = torch.zeros((self.num_envs, 16), device=self.device, dtype=torch.float32)
        
        
        
        
        if self.hand_tracking:
            self.hand_tracking_targets = torch.zeros((self.num_envs, 16), device=self.device, dtype=torch.float32)
            self.hand_tracking_period_count = torch.zeros((self.num_envs, ), device=self.device, dtype=torch.int32)
        
        if self.adjustable_rot_vel:
            self.rot_axis_tsr = to_torch([0, 0, 0], dtype=torch.float, device=self.device)
            if self.rot_axis == 'x':
                self.rot_axis_tsr[0] = self.rot_axis_mult
            elif self.rot_axis == 'y':
                self.rot_axis_tsr[1] = self.rot_axis_mult
            elif self.rot_axis == 'z':
                self.rot_axis_tsr[2] = self.rot_axis_mult
            self.change_rot_axis_period = 200
            self.rot_axis_step = 0
            self.max_rot_vel, self.min_rot_vel = 0.4, 0.1
            self.envs_rot_vel = torch.rand(self.num_envs).to(self.device).float() * (self.max_rot_vel - self.min_rot_vel) + self.min_rot_vel

        if self.add_force_obs:
            sensor_tensor = self.gym.acquire_force_sensor_tensor(self.sim)
            self.vec_sensor_tensor = gymtorch.wrap_tensor(sensor_tensor).view(self.num_envs, len(self.fingertips) * 6)

            dof_force_tensor = self.gym.acquire_dof_force_tensor(self.sim)
            self.dof_force_tensor = gymtorch.wrap_tensor(dof_force_tensor).view(self.num_envs,
                                                    -1)
            self.dof_force_tensor = self.dof_force_tensor[:, :self.num_allegro_hand_dofs]
            

        if self.grasp_to_grasp:
            self.goal_object_pose = torch.zeros((self.num_envs, 7), dtype=torch.float32, device=self.device)
            self.goal_hand_pose = torch.zeros((self.num_envs, 16), dtype=torch.float32, device=self.device)
        

        # debug and understanding statistics
        self.env_timeout_counter = to_torch(np.zeros((len(self.envs)))).long().to(self.device)  # max 10 (10000 envs)
        self.stat_sum_rewards = 0
        self.stat_sum_rotate_rewards = 0
        self.stat_sum_episode_length = 0
        self.stat_sum_obj_linvel = 0
        self.stat_sum_torques = 0
        self.env_evaluated = 0
        self.max_evaluate_envs = self.config['env'].get('maxEvaluateEnvs', 500000) #  500000
        
        self.ref_ts = 0
        self.test = self.config['env']['test']
        if self.test:
            self.ts_to_reset_info = {}
        
        
        if self.evaluate:
            self.maxx_episode_length = self.max_episode_length  + 0 # 500
            self.sv_cache_nn = 0
            self.shadow_hand_dof_pos_buf = torch.zeros((self.num_envs, self.maxx_episode_length, self.num_allegro_hand_dofs), device=self.device, dtype=torch.float32)
            self.object_pose_buf = torch.zeros((self.num_envs, self.maxx_episode_length, 7), device=self.device, dtype=torch.float32)
            self.shadow_hand_dof_tars_buf = torch.zeros((self.num_envs, self.maxx_episode_length, self.num_allegro_hand_dofs), device=self.device, dtype=torch.float32)
            self.rot_axis_totep_buf = torch.zeros((self.num_envs, self.maxx_episode_length, 3)    , device=self.device, dtype=torch.float32)
            
             
            
            self.reward_buf = torch.zeros((self.num_envs, self.maxx_episode_length, 1), device=self.device, dtype=torch.float32)
            
            self.ep_rotr_buf = torch.zeros((self.num_envs, self.maxx_episode_length, 1), device=self.device, dtype=torch.float32)
            self.ep_rotp_buf = torch.zeros((self.num_envs, self.maxx_episode_length, 1), device=self.device, dtype=torch.float32)
            
            self.value_buf = torch.zeros((self.num_envs, self.maxx_episode_length, 1), device=self.device, dtype=torch.float32)
            
            self.value_vals = torch.zeros((self.num_envs, 1), device=self.device, dtype=torch.float32)
            self.extrin_dim = 8
            self.extrin_buf = torch.zeros((self.num_envs, self.maxx_episode_length, self.extrin_dim), device=self.device, dtype=torch.float32)
            
            self.hand_root_ornt_buf = torch.zeros((self.num_envs, self.maxx_episode_length, 4), device=self.device, dtype=torch.float32) 
            
            if self.evaluate_goal_conditioned:
                self.target_obj_pose_buf = torch.zeros((self.num_envs, 4), device=self.device, dtype=torch.float32)
            if self.evaluate_for_statistics:
                self.evaluated_progress_length = []
                self.evaluated_ep_rotr = []
                self.evaluated_ep_rotp = []
                self.evaluated_ep_rew = []
                self.evaluated_ep_goal_cond_succ = []
                self.evaluated_orientation_diff = []
        if self.train_goal_conditioned:
            self.target_obj_pose_buf = torch.zeros((self.num_envs, 4), device=self.device, dtype=torch.float32)
                
            
    
    
    def build_pk_chain_finger(self,):
        """Build the pytorch_kinematics chain for the LEAP hand.

        Reads the right-hand URDF from ``assets/leap_hand/leap_hand_right.urdf``
        and stores three attributes on ``self``:

        * ``chain`` -- the kinematic chain, moved to float32 on CUDA.
        * ``isaac_order_to_pk_order`` -- long CUDA index tensor used to permute
          a 16-dim joint vector before it is passed to the chain
          (joints 0-3, then 8-15, then 4-7).
        * ``fingertip_names`` -- the link names whose transforms are queried
          by ``forward_pk_chain_for_finger_pos``.
        """
        leap_urdf_path = "assets/leap_hand/leap_hand_right.urdf"
        # Use a context manager so the file handle is always released.
        with open(leap_urdf_path) as urdf_file:
            urdf_text = urdf_file.read()
        chain = pk.build_chain_from_urdf(urdf_text)
        chain = chain.to(dtype=torch.float32, device='cuda')

        self.chain = chain
        # Permutation: keep joints 0..3, then take 8..15, then 4..7.
        pk_order = list(range(4)) + list(range(8, 16)) + [4, 5, 6, 7]
        self.isaac_order_to_pk_order = torch.tensor(pk_order, dtype=torch.long).cuda()
        self.fingertip_names = [
            'index_tip_head', 'thumb_tip_head', 'middle_tip_head', 'ring_tip_head'
        ]
    
    
    def forward_pk_chain_for_finger_pos(self, joint_angles):
        """Run forward kinematics and collect the fingertip transforms.

        Args:
            joint_angles: tensor whose last dimension is indexed with
                ``self.isaac_order_to_pk_order`` before being passed to
                ``self.chain.forward_kinematics``.

        Returns:
            A 3-tuple ``(flat_trans, flat_rot, per_finger_trans)``:

            * ``flat_trans`` -- per-fingertip translations (``[:, :3, 3]`` of
              each transform matrix) concatenated along the last dimension.
            * ``flat_rot`` -- per-fingertip rotation blocks (``[:, :3, :3]``)
              concatenated along the last dimension.
            * ``per_finger_trans`` -- the same translations stacked on a new
              dimension 1 (one entry per name in ``self.fingertip_names``).
        """
        pk_joint_angles = joint_angles[..., self.isaac_order_to_pk_order]
        tg_batch = self.chain.forward_kinematics(pk_joint_angles)

        finger_trans_list = []
        finger_rot_list = []
        for finger_name in self.fingertip_names:
            # Fetch the homogeneous transform once and slice it twice.
            finger_matrix = tg_batch[finger_name].get_matrix()
            finger_trans_list.append(finger_matrix[:, :3, 3])
            finger_rot_list.append(finger_matrix[:, :3, :3])

        per_finger_trans = torch.stack(finger_trans_list, dim=1)
        flat_trans = torch.cat(finger_trans_list, dim=-1)
        flat_rot = torch.cat(finger_rot_list, dim=-1)

        return flat_trans, flat_rot, per_finger_trans
    
    
    
    def _init_and_load_policy_model(self, ):
        """Create the policy network and its observation normaliser.

        Builds an ``ActorCritic`` sized from ``config['env']['numActions']``
        and ``config['env']['numObservations']`` plus a matching
        ``RunningMeanStd``, moves both to ``self.device`` and puts them in
        eval mode. If ``config['env']['policyCheckpoint']`` names an existing
        file, the ``'model'`` and ``'running_mean_std'`` entries of that
        checkpoint are loaded into the two modules.

        Sets ``self.policy_model`` and ``self.policy_running_mean_std``.
        """
        from hora.algo.models.models import ActorCritic
        from hora.algo.models.running_mean_std import RunningMeanStd

        num_actions = self.config['env']['numActions']
        nn_obs_policy = self.config['env']['numObservations']

        net_config = {
            'actor_units':  [512, 256, 128],
            'priv_mlp_units': [256, 128, 8],
            'actions_num': num_actions,
            'input_shape': (nn_obs_policy, ),
            'priv_info': True,
            'proprio_adapt': False,
            'priv_info_dim': 9,
        }

        self.policy_model = ActorCritic(net_config)
        self.policy_model.to(self.device)
        self.policy_model.eval()
        self.policy_running_mean_std = RunningMeanStd((nn_obs_policy, )).to(self.device)
        self.policy_running_mean_std.eval()

        policy_ckpt_fn = self.config['env'].get('policyCheckpoint', '')
        if len(policy_ckpt_fn) > 0 and os.path.exists(policy_ckpt_fn):
            print(f"Loading policy from {policy_ckpt_fn}")
            # map_location keeps loading robust when the checkpoint was saved
            # on a device that does not exist on this machine.
            policy_ckpt = torch.load(policy_ckpt_fn, map_location=self.device)
            self.policy_model.load_state_dict(policy_ckpt['model'])
            self.policy_running_mean_std.load_state_dict(policy_ckpt['running_mean_std'])
        elif len(policy_ckpt_fn) > 0:
            # Best-effort behaviour is kept (no exception), but make the
            # fallback to untrained weights visible.
            print(f"[Warning] policy checkpoint not found: {policy_ckpt_fn}; keeping freshly initialised weights")
    
    def _inference_policy_model(self,):
        policy_model_input_dict = {
            'obs': self.policy_running_mean_std(self.obs_buf),
            'priv_info': self.priv_info_buf
        }
        action = self.policy_model.act_inference(policy_model_input_dict)
        action = torch.clamp(action, -1.0, 1.0)
        return action
        
        
    def _init_and_load_action_compesator(self, ):
        """Create the action-compensator network and its input normaliser.

        The observation width (and, for some modes, the action width) is
        selected from the compensator flags on ``self``:

        * ``train_action_compensator_uan``       -> 656 observations
        * ``fingertip_only_action_compensator``   -> 8 observations, 4 actions
        * ``action_compensator_not_use_history``  -> 16 + numActions
        * otherwise                               -> 32 * 3 + numActions

        When both ``train_action_compensator_uan`` and
        ``per_joint_action_compensator`` are set, the network takes 21 inputs,
        emits 1 action, uses no privileged info, and the normaliser is sized
        to 336.

        Sets ``self.action_compensator_model``,
        ``self.action_compensator_running_mean_std``,
        ``self.act_compensator_obs_buf_lag_history`` and resets
        ``self.compensated_targets`` to ``None``. If
        ``config['env']['actionCompensatorCheckpoint']`` names an existing
        file, its ``'model'`` and ``'running_mean_std'`` entries are loaded.
        """
        from hora.algo.models.models import ActorCritic
        from hora.algo.models.running_mean_std import RunningMeanStd
        num_actions = self.config['env']['numActions']

        if self.train_action_compensator_uan:
            nn_obs_act_compensator = 656
        elif self.fingertip_only_action_compensator:
            nn_obs_act_compensator = 8
            num_actions = 4
        elif self.action_compensator_not_use_history:
            nn_obs_act_compensator = 16 + num_actions
        else:
            nn_obs_act_compensator = 32 * 3 + num_actions

        priv_info = True

        # By default the normaliser sees the same width as the network input.
        nn_obs_act_compensator_running_mean_std = nn_obs_act_compensator + 0

        if self.train_action_compensator_uan and self.per_joint_action_compensator:
            # NOTE(review): the `else 2` arm is unreachable under the enclosing
            # condition -- confirm whether the guard was meant to be
            # `per_joint_action_compensator` alone.
            nn_obs_act_compensator = 21 if self.train_action_compensator_uan else 2
            num_actions = 1
            priv_info = False
            nn_obs_act_compensator_running_mean_std = 336

        net_config = {
            'actor_units':  [512, 256, 128],
            'priv_mlp_units': [256, 128, 8],
            'actions_num': num_actions,
            'input_shape': (nn_obs_act_compensator, ),
            'priv_info': priv_info,
            'proprio_adapt': False,
            'priv_info_dim': 9,
        }

        self.action_compensator_model = ActorCritic(net_config)
        # Mode flags are attached to the model instance after construction.
        self.action_compensator_model.train_action_compensator = True
        self.action_compensator_model.per_joint_action_compensator = self.per_joint_action_compensator
        self.action_compensator_model.train_action_compensator_uan = self.train_action_compensator_uan
        self.action_compensator_model.to(self.device)
        self.action_compensator_model.eval()

        self.action_compensator_running_mean_std = RunningMeanStd((nn_obs_act_compensator_running_mean_std, )).to(self.device)
        self.action_compensator_running_mean_std.eval()
        act_compensator_ckpt_fn = self.config['env'].get('actionCompensatorCheckpoint', '')
        if len(act_compensator_ckpt_fn) > 0 and os.path.exists(act_compensator_ckpt_fn):
            print(f"Loading action compensator from {act_compensator_ckpt_fn}")
            # map_location keeps loading robust when the checkpoint was saved
            # on a device that does not exist on this machine.
            act_compensator_ckpt = torch.load(act_compensator_ckpt_fn, map_location=self.device)
            self.action_compensator_model.load_state_dict(act_compensator_ckpt['model'])
            self.action_compensator_running_mean_std.load_state_dict(act_compensator_ckpt['running_mean_std'])
        elif len(act_compensator_ckpt_fn) > 0:
            # Best-effort behaviour is kept (no exception), but make the
            # fallback to untrained weights visible.
            print(f"[Warning] action compensator checkpoint not found: {act_compensator_ckpt_fn}; keeping freshly initialised weights")

        # Lagged observation buffer: (num_envs, 80, lag_history_len).
        self.act_compensator_obs_buf_lag_history = torch.zeros((
            self.num_envs, 80, self.lag_history_len
        ), device=self.device, dtype=torch.float)
        self.compensated_targets = None
        
    def _init_and_load_obj_motion_pred_model(self, ):
        """Build the object-motion prediction model and load its weights.

        Reads ``controlseq.yml`` from
        ``../IsaacGymEnvs2/isaacgymenvs/ddim/configs``, overrides the
        ``invdyn`` section for this model, instantiates
        ``ModelInvDynObjMotionPred`` on CUDA, and loads the state dict stored
        at index 0 of the checkpoint at
        ``self.train_action_compensator_w_obj_motion_pred_model_fn``.

        Sets ``self.invdyn_v2_config_path`` and ``self.obj_motion_pred_model``
        (left in eval mode).
        """
        from ddim.models.diffusion_controlseq import ModelInvDynObjMotionPred

        self.invdyn_v2_config_path = 'controlseq.yml'
        with open(os.path.join("../IsaacGymEnvs2/isaacgymenvs/ddim/configs", self.invdyn_v2_config_path), "r") as f:
            config = yaml.safe_load(f)
        invdyn_config = dict2namespace(config)

        invdyn_config.invdyn.history_length = 30
        invdyn_config.invdyn.history_obs_dim = 32
        invdyn_config.invdyn.train_obj_motion_pred_model = True
        invdyn_config.invdyn.model_arch = 'resmlp'
        invdyn_config.invdyn.pred_extrin = False
        # The original set res_blocks to 2 and then immediately to 5; only the
        # effective value is kept.
        invdyn_config.invdyn.res_blocks = 5

        self.obj_motion_pred_model = ModelInvDynObjMotionPred(invdyn_config)
        self.obj_motion_pred_model.cuda()

        # Load onto CPU first; load_state_dict copies into the CUDA parameters,
        # so this also works if the checkpoint was saved on another GPU index.
        obj_motion_pred_model_ckpt = torch.load(
            self.train_action_compensator_w_obj_motion_pred_model_fn,
            map_location='cpu',
        )[0]
        self.obj_motion_pred_model.load_state_dict(obj_motion_pred_model_ckpt)

        self.obj_motion_pred_model.eval()
    
    def _init_and_load_inverse_dynamics_model(self, ):
        """Build the inverse-dynamics model and load its weights.

        Reads ``controlseq.yml`` from
        ``../IsaacGymEnvs2/isaacgymenvs/ddim/configs``, overrides the
        ``invdyn`` section for this model, instantiates
        ``InverseDynamicsModel`` on CUDA, and loads the state dict stored at
        index 0 of the checkpoint at
        ``self.action_compensator_invaction_ckpt_fn``.

        Sets ``self.invdyn_v2_config_path`` and ``self.inverse_dynamics_model``
        (left in eval mode).
        """
        from ddim.models.diffusion_controlseq import InverseDynamicsModel

        self.invdyn_v2_config_path = 'controlseq.yml'
        with open(os.path.join("../IsaacGymEnvs2/isaacgymenvs/ddim/configs", self.invdyn_v2_config_path), "r") as f:
            config = yaml.safe_load(f)
        invdyn_config = dict2namespace(config)

        invdyn_config.invdyn.history_length = 30
        invdyn_config.invdyn.history_obs_dim = 32
        invdyn_config.invdyn.train_obj_motion_pred_model = True
        invdyn_config.invdyn.model_arch = 'resmlp'
        invdyn_config.invdyn.pred_extrin = False
        # The original set res_blocks to 2 and then immediately to 5; only the
        # effective value is kept.
        invdyn_config.invdyn.res_blocks = 5

        # Settings specific to the inverse-dynamics variant.
        invdyn_config.invdyn.wm_history_length = 1
        invdyn_config.invdyn.finger_idx = -1
        invdyn_config.invdyn.joint_idx = -1
        invdyn_config.invdyn.hist_context_length = 0
        invdyn_config.invdyn.train_inverse_dynamics_model = True

        self.inverse_dynamics_model = InverseDynamicsModel(invdyn_config)
        self.inverse_dynamics_model.cuda()

        # Load onto CPU first; load_state_dict copies into the CUDA parameters,
        # so this also works if the checkpoint was saved on another GPU index.
        inverse_dynamics_model_ckpt = torch.load(
            self.action_compensator_invaction_ckpt_fn,
            map_location='cpu',
        )[0]
        self.inverse_dynamics_model.load_state_dict(inverse_dynamics_model_ckpt)

        self.inverse_dynamics_model.eval()

        
    def _get_obj_motion_pred_input(self, ):
        """Run the object-motion predictor on simulated and replayed histories.

        The simulated input is the flattened lag history of joint positions
        and joint targets. The replayed ("real") input is gathered from
        ``self.envs_replay_qpos`` / ``self.envs_replay_qtars`` at the step
        indices ``progress_buf + offset`` for ``offset`` running from
        ``lagging_obs_length - 1`` down to 0, each clamped to
        ``[0, max_episode_length - 1]``.

        Returns:
            ``(sim_pred, real_pred)`` -- the outputs of
            ``self.obj_motion_pred_model`` for the simulated and the replayed
            input respectively.
        """
        n_envs = self.num_envs
        last_step = self.max_episode_length - 1

        # One clamped index column per offset, largest offset first.
        replay_step_idxes = torch.stack(
            [
                torch.clamp(self.progress_buf + offset, 0, last_step)
                for offset in reversed(range(self.lagging_obs_length))
            ],
            dim=1,
        )

        replay_qpos = batched_index_select(self.envs_replay_qpos, indices=replay_step_idxes, dim=1)
        replay_qtars = batched_index_select(self.envs_replay_qtars, indices=replay_step_idxes, dim=1)

        sim_features = torch.cat(
            [
                self.obs_buf_lag_history_qpos.contiguous().view(n_envs, -1),
                self.obs_buf_lag_history_qtars.contiguous().view(n_envs, -1),
            ],
            dim=-1,
        )
        real_features = torch.cat(
            [
                replay_qpos.contiguous().view(n_envs, -1),
                replay_qtars.contiguous().view(n_envs, -1),
            ],
            dim=-1,
        )

        sim_pred = self.obj_motion_pred_model(sim_features)
        real_pred = self.obj_motion_pred_model(real_features)

        return sim_pred, real_pred
    
        
    def _get_finger_world_model_prediction(self, cur_compensated_targets , use_compensated_targets=True):
        """Predict the next joint state with the per-joint world models.

        For every ``joint_idx`` in ``self.wm_pred_joint_idxes`` the world model
        ``self.joint_idx_to_wm[joint_idx]`` is evaluated on the recent joint
        positions and targets of the finger selected by
        ``self.action_compensator_output_finger_idx`` (its four consecutive
        joints). With ``self.train_action_compensator_w_residual_wm`` set, the
        output of ``self.joint_idx_to_prev_wm[joint_idx]`` is added on top.

        Args:
            cur_compensated_targets: current joint targets; appended as the
                newest entry of the target history.
            use_compensated_targets: selects the attribute that receives the
                result -- ``self.real_wm_pred_next_state`` when true,
                ``self.real_wm_pred_next_state_orijoints`` otherwise.

        Returns:
            None. The detached prediction is stored on ``self``.
        """
        # Most recent `wm_history_length` entries of the joint-position history.
        wm_hist_hand_qpos = self.obs_buf_lag_history_qpos[:, -self.wm_history_length: ] # nn_envs x wm_hist_length x 16 #

        if self.wm_history_length == 1:
            wm_hist_hand_qtars = cur_compensated_targets[:, None]
        else:
            # Older targets come from the history; the current target is the last entry.
            wm_hist_hand_qtars = self.obs_buf_lag_history_qtars[:, -self.wm_history_length + 1: ] # nn_envs x wm_hist_length x 16 #
            wm_hist_hand_qtars = torch.cat([wm_hist_hand_qtars, cur_compensated_targets[:, None]], dim=1)

        # Map joint values through `unscale` with the hand's DOF limits before feeding the models.
        unscaled_wm_hist_hand_qpos = unscale(wm_hist_hand_qpos, self.allegro_hand_dof_lower_limits, self.allegro_hand_dof_upper_limits) # nn_envs x wm_hist_length x 16 #
        unscaled_wm_hist_hand_qtars = unscale(wm_hist_hand_qtars, self.allegro_hand_dof_lower_limits, self.allegro_hand_dof_upper_limits) # nn_envs x wm_hist_length x 16 #

        use_hist_context = self.hist_context_length > 0
        if use_hist_context:
            # Optional longer context, taken from the raw (not unscaled) histories.
            hist_state = self.obs_buf_lag_history_qpos[:, -self.hist_context_length: ]
            hist_action = self.obs_buf_lag_history_qtars[:, -self.hist_context_length: ]

            hist_state = hist_state.contiguous().view(hist_state.size(0), -1).contiguous()
            hist_action = hist_action.contiguous().view(hist_action.size(0), -1).contiguous()

        # The selected finger -- and therefore the model input -- does not
        # depend on joint_idx, so it is built once instead of on every iteration.
        finger_idx = self.action_compensator_output_finger_idx
        finger_joint_idxes = list(range(finger_idx * 4, (finger_idx + 1) * 4))
        finger_joint_idxes = torch.tensor(finger_joint_idxes).long().to(self.device)

        finger_hist_qpos = unscaled_wm_hist_hand_qpos[..., finger_joint_idxes]
        finger_hist_qtars = unscaled_wm_hist_hand_qtars[..., finger_joint_idxes]
        finger_hist_qpos = finger_hist_qpos.contiguous().view(finger_hist_qpos.size(0), -1).contiguous()
        finger_hist_qtars = finger_hist_qtars.contiguous().view(finger_hist_qtars.size(0), -1).contiguous()

        tot_pred_next_state = []
        for joint_idx in self.wm_pred_joint_idxes:
            cur_world_model = self.joint_idx_to_wm[joint_idx]

            cur_finger_wm_input_dict = {
                'state': finger_hist_qpos,
                'action': finger_hist_qtars
            }
            if use_hist_context:
                cur_finger_wm_input_dict.update( {
                    'hist_state': hist_state, 'hist_action': hist_action}
                )
            cur_finger_pred_next_state = cur_world_model(cur_finger_wm_input_dict)

            if self.train_action_compensator_w_residual_wm:
                # Residual mode: add the previous model's prediction. That
                # model receives state/action only, never the history context.
                cur_prev_world_model = self.joint_idx_to_prev_wm[joint_idx]
                cur_prev_wm_input_dict = {
                    'state': finger_hist_qpos,
                    'action': finger_hist_qtars
                }
                cur_prev_wm_pred_next_state = cur_prev_world_model(cur_prev_wm_input_dict)
                cur_finger_pred_next_state = cur_finger_pred_next_state + cur_prev_wm_pred_next_state

            tot_pred_next_state.append(cur_finger_pred_next_state)

        # Concatenate the per-joint predictions along dim 1 and map them back
        # through `scale` with the limits of the predicted joints.
        tot_pred_next_state = torch.cat(tot_pred_next_state, dim=1)
        tot_pred_next_state = scale(tot_pred_next_state, self.allegro_hand_dof_lower_limits[self.sorted_figner_joint_idxes], self.allegro_hand_dof_upper_limits[self.sorted_figner_joint_idxes]) # nn envs x joint dof #
        tot_pred_next_state = tot_pred_next_state.contiguous().view(tot_pred_next_state.size(0), -1).contiguous()

        if use_compensated_targets:
            self.real_wm_pred_next_state = tot_pred_next_state.detach()
        else:
            self.real_wm_pred_next_state_orijoints = tot_pred_next_state.detach()
    
    
    def _get_full_hand_world_model_prediction(self, cur_compensated_targets, use_compensated_targets=True):
        """Predict the next hand state with the full-hand world model.

        The model input is the flattened concatenation of the last
        ``self.full_hand_wm_history_length`` entries of the joint-position
        history and of the compensated-target history. The model is switched
        to eval mode, called with that input and ``cur_compensated_targets``
        (``history_extrin=None``), and its output is passed through ``scale``
        with the hand's DOF limits.

        Args:
            cur_compensated_targets: current targets handed to the world model.
            use_compensated_targets: selects the attribute that receives the
                result -- ``self.real_wm_pred_next_state`` when true,
                ``self.real_wm_pred_next_state_orijoints`` otherwise.

        Returns:
            None. The detached prediction is stored on ``self``.
        """
        self.full_hand_wm.eval()

        hist_len = self.full_hand_wm_history_length
        qpos_window = self.obs_buf_lag_history_qpos[:, -hist_len:, ].clone()
        qtars_window = self.obs_buf_lag_history_compensated_qtars[:, -hist_len:, ].clone()

        # Flatten each history window to one row per environment.
        flat_qpos = qpos_window.contiguous().view(qpos_window.size(0), -1)
        flat_qtars = qtars_window.contiguous().view(qtars_window.size(0), -1)
        wm_state = torch.cat([flat_qpos, flat_qtars], dim=-1)

        predicted = self.full_hand_wm(wm_state, cur_compensated_targets, history_extrin=None)
        predicted = scale(predicted, self.allegro_hand_dof_lower_limits, self.allegro_hand_dof_upper_limits)

        if not use_compensated_targets:
            self.real_wm_pred_next_state_orijoints = predicted.detach()
            return
        self.real_wm_pred_next_state = predicted.detach()
            
    
    def _get_finger_world_model_prediction_perjoint(self, cur_compensated_targets, use_compensated_targets=True):
        """Run the per-joint world models to predict the next hand joint state.

        For every joint index in ``self.wm_pred_joint_idxes`` the matching model in
        ``self.joint_idx_to_wm`` is fed that joint's recent position history and
        target history (both passed through ``unscale`` with the joint limits),
        where the newest target entry is ``cur_compensated_targets``.  The
        per-joint outputs are concatenated, mapped back through ``scale`` with the
        limits of the predicted joints, and cached on ``self``.

        Side effects:
            * ``self.joint_idx_to_delta_action_model_input[joint_idx]`` receives
              cloned copies of each joint's ``state`` / ``action`` model inputs.
            * ``self.tot_pred_next_state`` receives a clone of the prediction.
            * The detached prediction is stored in ``self.real_wm_pred_next_state``
              when ``use_compensated_targets`` is true, otherwise in
              ``self.real_wm_pred_next_state_orijoints``.

        Args:
            cur_compensated_targets: targets used as the newest history entry; it is
                indexed as ``[:, None]``, so presumably (num_envs, num_dofs) -- confirm.
            use_compensated_targets: selects which attribute stores the detached result.
        """
        hist_len = self.wm_history_length
        lower_limits = self.allegro_hand_dof_lower_limits
        upper_limits = self.allegro_hand_dof_upper_limits

        # Most recent `hist_len` joint positions (cloned so the buffer is untouched).
        qpos_window = self.obs_buf_lag_history_qpos[:, -hist_len:].clone()

        # Target window: the last `hist_len - 1` recorded targets followed by the new one.
        # hist_len == 1 is special-cased because the slice `-hist_len + 1:` would then
        # be `0:` and select the entire buffer instead of nothing.
        newest_target = cur_compensated_targets[:, None]
        if hist_len == 1:
            qtars_window = newest_target
        else:
            past_targets = self.obs_buf_lag_history_qtars[:, -hist_len + 1:]
            qtars_window = torch.cat([past_targets, newest_target], dim=1)

        unscaled_qpos = unscale(qpos_window, lower_limits, upper_limits)
        unscaled_qtars = unscale(qtars_window, lower_limits, upper_limits)

        # Optional longer context shared by all joints; taken from the raw
        # (not unscaled) history buffers and flattened per env.
        context_inputs = {}
        if self.hist_context_length > 0:
            ctx_len = self.hist_context_length
            ctx_state = self.obs_buf_lag_history_qpos[:, -ctx_len:]
            ctx_action = self.obs_buf_lag_history_qtars[:, -ctx_len:]
            context_inputs = {
                'hist_state': ctx_state.contiguous().view(ctx_state.size(0), -1).contiguous(),
                'hist_action': ctx_action.contiguous().view(ctx_action.size(0), -1).contiguous(),
            }

        per_joint_preds = []
        for joint_idx in self.wm_pred_joint_idxes:
            world_model = self.joint_idx_to_wm[joint_idx]
            joint_selector = torch.tensor([joint_idx]).long().to(self.device)

            # Pick this joint's column and flatten the history axis per env.
            joint_qpos = unscaled_qpos[..., joint_selector]
            joint_qtars = unscaled_qtars[..., joint_selector]
            joint_qpos = joint_qpos.contiguous().view(joint_qpos.size(0), -1).contiguous()
            joint_qtars = joint_qtars.contiguous().view(joint_qtars.size(0), -1).contiguous()

            # Keep a private copy of the inputs for the delta-action model.
            self.joint_idx_to_delta_action_model_input[joint_idx] = {
                'state': joint_qpos.clone(),
                'action': joint_qtars.clone(),
            }

            wm_inputs = {'state': joint_qpos, 'action': joint_qtars, **context_inputs}
            joint_pred = world_model(wm_inputs)

            if self.train_action_compensator_w_residual_wm:
                # Residual mode: add the previous world model's prediction; that
                # model is called without the extra history context.
                base_model = self.joint_idx_to_prev_wm[joint_idx]
                base_pred = base_model({'state': joint_qpos, 'action': joint_qtars})
                joint_pred = joint_pred + base_pred

            per_joint_preds.append(joint_pred.clone())

        # Concatenate per-joint outputs along dim 1 and map them back with the
        # limits of the predicted joints only.
        pred_joint_idxes = self.wm_pred_joint_idxes
        prediction = torch.cat(per_joint_preds, dim=1)
        prediction = scale(prediction, lower_limits[pred_joint_idxes], upper_limits[pred_joint_idxes])
        prediction = prediction.contiguous().view(prediction.size(0), -1).contiguous()

        self.tot_pred_next_state = prediction.clone()
        detached_prediction = prediction.detach()
        if use_compensated_targets:
            self.real_wm_pred_next_state = detached_prediction
        else:
            self.real_wm_pred_next_state_orijoints = detached_prediction
    
    
    
    def _inference_act_compensator(self, obs, cur_action):
        
        
        if self.train_action_compensator_uan:
            history_qtars, history_qpos = self.obs_buf_lag_history_qtars[:, -19:].clone(), self.obs_buf_lag_history_qpos[:, -19:].clone() 
            history_qpos = torch.cat(
                [history_qpos, self.allegro_hand_dof_pos.unsqueeze(1)], dim=1
            )
            history_qtars = torch.cat(
                [ history_qtars,  self.cur_targets.unsqueeze(1)], dim=1
            )
            history_e_qpos_to_qtars = history_qtars - history_qpos
            flatten_e = history_e_qpos_to_qtars.contiguous().view(history_e_qpos_to_qtars.size(0), -1).contiguous()
            t_buf = flatten_e
        else:
            prev_obs_buf = self.act_compensator_obs_buf_lag_history[:, 1:].clone()
            joint_noise_matrix = (torch.rand(self.allegro_hand_dof_pos.shape) * 2.0 - 1.0) * self.joint_noise_scale
            cur_obs_buf = unscale(
                joint_noise_matrix.to(self.device) + self.allegro_hand_dof_pos, self.allegro_hand_dof_lower_limits, self.allegro_hand_dof_upper_limits
            ).clone().unsqueeze(1)
            if self.compensated_targets is not None:
                cur_tar_buf = self.compensated_targets[:, None]
            else:
                cur_tar_buf = self.cur_targets[:, None]
            cur_obs_buf = torch.cat([cur_obs_buf, cur_tar_buf], dim=-1)
            self.act_compensator_obs_buf_lag_history[:] = torch.cat([prev_obs_buf, cur_obs_buf], dim=1)

            # refill the initialized buffers
            at_reset_env_ids = self.at_reset_buf.nonzero(as_tuple=False).squeeze(-1)
            self.act_compensator_obs_buf_lag_history[at_reset_env_ids, :, 0:16] = unscale(
                self.allegro_hand_dof_pos[at_reset_env_ids], self.allegro_hand_dof_lower_limits,
                self.allegro_hand_dof_upper_limits
            ).clone().unsqueeze(1)
            self.act_compensator_obs_buf_lag_history[at_reset_env_ids, :, 16:32] = self.allegro_hand_dof_pos[at_reset_env_ids].unsqueeze(1)
            t_buf = (self.act_compensator_obs_buf_lag_history[:, -3:].reshape(self.num_envs, -1)).clone()
            
            
            if self.action_compensator_not_use_history:
                if self.fingertip_only_action_compensator:
                    t_buf = self.act_compensator_obs_buf_lag_history[:, -1, [3, 7, 11, 15]].clone()
                else:
                    t_buf = self.act_compensator_obs_buf_lag_history[:, -1, :16].clone() # action compensator's observation that do not use history #
            
        if self.fingertip_only_action_compensator:
            fingertip_action = cur_action[..., [3, 7, 11, 15]]
            act_compensator_obs = torch.cat(
                [ t_buf,  fingertip_action], dim=-1
            )
        else:
            act_compensator_obs = torch.cat(
                [ t_buf,  cur_action], dim=-1
            )
        
        # if self.train_action_compensator_uan:
        #     act_compensator_obs = torch.cat(
        #         [ act_compensator_obs, torch.zeros((act_compensator_obs.size(0), 320)).to(self.device) ], dim=-1
        #     )
        #     # 
        
        priv_info =  self.priv_info_buf #  self.obs_dict['priv_info']
        act_compensator_input_dict = {
            'obs': self.action_compensator_running_mean_std(act_compensator_obs),
            'priv_info': priv_info
        }
        mu = self.action_compensator_model.act_inference(act_compensator_input_dict)
        mu = torch.clamp(mu, -1.0, 1.0)
        return mu
        
        

    def _create_envs(self, num_envs, spacing, num_per_row):
        """Create all simulation environments, each holding one hand actor and one object actor.

        Steps performed, in order:
          1. create the ground plane and the object assets;
          2. read the hand DOF limits and fill in the DOF drive properties
             (effort, stiffness/damping or effort mode, friction, armature);
          3. for each env: create the hand and object actors, then apply the
             enabled randomizations (object scale, centre of mass, friction,
             hand mass/inertia for auto-tuning, preset identified masses/inertias,
             object mass) and record the values via ``self._update_priv_buf``;
          4. convert the per-env bookkeeping lists into tensors / numpy arrays.

        Args:
            num_envs: number of environments created by the main loop.
            spacing: half extent used to build the env bounds passed to ``create_env``.
            num_per_row: forwarded to ``create_env``.

        NOTE(review): several buffers are sized with ``self.num_envs`` while the loop
        runs over the ``num_envs`` argument; this assumes the two are equal -- confirm.
        """
        print(f"Start creating envs")
        self._create_ground_plane()
        lower = gymapi.Vec3(-spacing, -spacing, 0.0)
        upper = gymapi.Vec3(spacing, spacing, spacing)

        self._create_object_asset()

        # Per-env counter; incremented below each time a guiding pose is written for that env.
        self.upd_guiding_pose_steps = torch.zeros((self.num_envs,), device=self.device, dtype=torch.float32)

        # set allegro_hand dof properties
        self.num_allegro_hand_dofs = self.gym.get_asset_dof_count(self.hand_asset)
        allegro_hand_dof_props = self.gym.get_asset_dof_properties(self.hand_asset)

        self.allegro_hand_dof_lower_limits = []
        self.allegro_hand_dof_upper_limits = []

        # When auto-tuning without preset gains, draw random PD gains per env and per DOF.
        # They are only stored here; this method does not write them into allegro_hand_dof_props.
        if self.real_to_sim_auto_tune and (not self.preset_pd_gains):
            self.envs_dofs_rnd_pgains = []
            self.envs_dofs_rnd_dgains = []

        for i in range(self.num_allegro_hand_dofs):
            self.allegro_hand_dof_lower_limits.append(allegro_hand_dof_props['lower'][i])
            self.allegro_hand_dof_upper_limits.append(allegro_hand_dof_props['upper'][i])
            allegro_hand_dof_props['effort'][i] = 0.5
            if self.torque_control:
                # Torque control: zero the built-in PD terms and switch the DOF to effort mode.
                allegro_hand_dof_props['stiffness'][i] = 0.
                allegro_hand_dof_props['damping'][i] = 0.
                allegro_hand_dof_props['driveMode'][i] = gymapi.DOF_MODE_EFFORT
            else:
                # Otherwise use the PD gains from the controller config.
                allegro_hand_dof_props['stiffness'][i] = self.config['env']['controller']['pgain']
                allegro_hand_dof_props['damping'][i] = self.config['env']['controller']['dgain']
            allegro_hand_dof_props['friction'][i] = 0.01
            allegro_hand_dof_props['armature'][i] = 0.001

            if self.real_to_sim_auto_tune and (not self.preset_pd_gains):
                # pgain [1, 5] (with step 0.1 if discrete grid search) (uniform distribution if continuous range search)
                # dgain [0.01, 0.2] (with step 0.01 if discrete grid search) (uniform distribution if continuous range search )
                cur_dof_envs_rnd_pgains = torch.rand(self.num_envs).to(self.device).float() * (5 - 1) + 1 # uniform values drawing from the range (1, 5)
                cur_dof_envs_rnd_dgains = torch.rand(self.num_envs).to(self.device).float() * (0.2 - 0.01) + 0.01 # uniform values drawing from the range (0.01, 0.2)
                self.envs_dofs_rnd_pgains.append(cur_dof_envs_rnd_pgains)
                self.envs_dofs_rnd_dgains.append(cur_dof_envs_rnd_dgains)

        if self.real_to_sim_auto_tune and (not self.preset_pd_gains):
            # Stack the per-DOF vectors (each of length self.num_envs) along dim 1.
            self.envs_dofs_rnd_pgains = torch.stack(self.envs_dofs_rnd_pgains, dim=1)
            self.envs_dofs_rnd_dgains = torch.stack(self.envs_dofs_rnd_dgains, dim=1) # (nn_envs x nn_dofs)

        self.allegro_hand_dof_lower_limits = to_torch(self.allegro_hand_dof_lower_limits, device=self.device)
        self.allegro_hand_dof_upper_limits = to_torch(self.allegro_hand_dof_upper_limits, device=self.device)


        print(f"allegro_hand_dof_lower_limits: {self.allegro_hand_dof_lower_limits}")
        print(f"allegro_hand_dof_upper_limits: {self.allegro_hand_dof_upper_limits}")

        # Initial hand / object poses; the same two poses are used for every env below.
        hand_pose, obj_pose = self._init_object_pose()



        self.hand_pose = hand_pose

        # Hand pose as a flat tensor: position (x, y, z) followed by quaternion (x, y, z, w).
        self.hand_pose_tsr = to_torch([hand_pose.p.x, hand_pose.p.y, hand_pose.p.z,
                              hand_pose.r.x, hand_pose.r.y, hand_pose.r.z, hand_pose.r.w], device=self.device)


        # compute aggregate size
        self.num_allegro_hand_bodies = self.gym.get_asset_rigid_body_count(self.hand_asset)
        self.num_allegro_hand_shapes = self.gym.get_asset_rigid_shape_count(self.hand_asset)
        max_agg_bodies = self.num_allegro_hand_bodies + 2
        max_agg_shapes = self.num_allegro_hand_shapes + 2

        print(f"max_agg_bodies: {max_agg_bodies}    max_agg_shapes: {max_agg_shapes}")

        self.envs = []

        self.object_init_state = []

        self.hand_indices = []
        self.object_indices = []
        self.envs_obj_indices = []

        # The object is assumed to be a single rigid body whose handle follows the hand's bodies.
        allegro_hand_rb_count = self.gym.get_asset_rigid_body_count(self.hand_asset)
        object_rb_count = 1
        self.object_rb_handles = list(range(allegro_hand_rb_count, allegro_hand_rb_count + object_rb_count))

        if self.add_force_obs:
            # Attach one force sensor (identity transform) to each fingertip body of the hand asset.
            sensor_pose = gymapi.Transform()
            for ft_handle in self.fingertip_handles:
                print(f"fingertip_handles: {ft_handle}")
                self.gym.create_asset_force_sensor(self.hand_asset, ft_handle, sensor_pose)


        # Guiding orientation per env, initialised to the identity quaternion (x, y, z, w).
        self.guiding_pose =  to_torch([0, 0, 0, 1], dtype=torch.float, device=self.device).repeat((self.num_envs, 1)) 

        self.envs_obj_points = []
        self.envs_hand_actor_rigid_body_mass_list = []
        self.envs_hand_actor_rigid_body_inertia_values = []
        self.envs_hand_actor_rigid_body_friction_list = []

        self.envs_scale_list = []
        self.envs_policy_idx = []
        self.policy_idx_to_env_list = {}

        for i in range(num_envs):
            # create env instance
            env_ptr = self.gym.create_env(self.sim, lower, upper, num_per_row)
            if self.aggregate_mode >= 1:
                # NOTE(review): the aggregate limits are 100x the computed sizes; the reason is not visible here.
                self.gym.begin_aggregate(env_ptr, max_agg_bodies * 100, max_agg_shapes * 100, True)

            # add hand - collision filter = -1 to use asset collision filters set in mjcf loader
            hand_actor = self.gym.create_actor(env_ptr, self.hand_asset, hand_pose, 'hand', i, -1, 0)
            self.gym.set_actor_dof_properties(env_ptr, hand_actor, allegro_hand_dof_props)
            hand_idx = self.gym.get_actor_index(env_ptr, hand_actor, gymapi.DOMAIN_SIM)
            self.hand_indices.append(hand_idx)



            # Choose the object for this env: uniformly among the specified objects if any,
            # otherwise from the full object list using self.object_type_prob.
            if len(self.specified_obj_idx) > 0:
                # object_type_id = 0
                object_type_id = np.random.choice(len(self.specified_obj_idx))
                obj_index = self.specified_obj_idx[object_type_id]
            else:
                object_type_id = np.random.choice(len(self.object_type_list), p=self.object_type_prob)
                obj_index = object_type_id

            # Group envs by the object type ("policy idx") of the object they hold.
            cur_obj_type_idx = self.obj_idx_to_obj_type[obj_index]
            self.envs_policy_idx.append(cur_obj_type_idx)

            if cur_obj_type_idx not in self.policy_idx_to_env_list:
                self.policy_idx_to_env_list[cur_obj_type_idx] = []
            self.policy_idx_to_env_list[cur_obj_type_idx].append(i)


            # NOTE(review): the asset (and the features/points below) are looked up with
            # object_type_id, not obj_index; with specified_obj_idx set these differ unless
            # the asset lists were built only from the specified objects -- confirm.
            object_asset = self.object_asset_list[object_type_id]

            if self.add_obj_features:
                obj_feature = self.asset_feature_dict[self.object_type_list[object_type_id]]
                self.envs_obj_features[i, :] = obj_feature[:]


            if self.add_zrot_penalty:
                # cur_obj_points = self.pts_assets_dict[objec]
                cur_obj_points = self.object_points_list[object_type_id]
                cur_obj_points_th = torch.from_numpy(cur_obj_points).float().to(self.device)
                self.envs_obj_points.append(cur_obj_points_th)


            self.envs_obj_indices.append(obj_index)

            # In the "free hand" modes below the object gets collision group i + num_envs,
            # i.e. a group different from the hand's (i); presumably so hand and object do
            # not collide. Note `and` binds tighter than `or` in the last parenthesised term.
            if (self.real_to_sim_auto_tune and (not self.real_to_sim_auto_tune_w_obj)) or (self.hand_tracking and self.hand_tracking_nobj) or (self.openloop_replay) or (self.train_action_compensator and (not self.action_compensator_w_obj) or (self.train_action_compensator_w_real_wm and (self.train_action_compensator_free_hand))):
                obj_collision_group = i + self.num_envs
            else:
                obj_collision_group = i

            # object_handle = self.gym.create_actor(env_ptr, object_asset, obj_pose, 'object', i, 0, 0)
            object_handle = self.gym.create_actor(env_ptr, object_asset, obj_pose, 'object', obj_collision_group, 0, 0)
            # Initial object root state: position (3), quaternion xyzw (4), then six zeros.
            self.object_init_state.append([
                obj_pose.p.x, obj_pose.p.y, obj_pose.p.z,
                obj_pose.r.x, obj_pose.r.y, obj_pose.r.z, obj_pose.r.w,
                0, 0, 0, 0, 0, 0
            ])
            object_idx = self.gym.get_actor_index(env_ptr, object_handle, gymapi.DOMAIN_SIM)
            self.object_indices.append(object_idx)



            if self.add_aux_pose_guidance:
                cur_obj_rot = to_torch([obj_pose.r.x, obj_pose.r.y, obj_pose.r.z, obj_pose.r.w], dtype=torch.float, device=self.device)
                # get the object rot axis #
                obj_rot_axis = torch.tensor([0, 0, 0], dtype=torch.float, device=self.device)
                if self.rot_axis == 'x':
                    obj_rot_axis[0] = self.rot_axis_mult
                elif self.rot_axis == 'y':
                    obj_rot_axis[1] = self.rot_axis_mult
                elif self.rot_axis == 'z':
                    obj_rot_axis[2] = self.rot_axis_mult

                # Guiding pose = rotation by guiding_delta_rot_radian about the axis, composed with the initial object rotation.
                cur_obj_guiding_rot = quat_from_angle_axis(  self.guiding_delta_rot_radian * torch.ones((1,), dtype=torch.float32).to(self.device) , obj_rot_axis.unsqueeze(0))
                # cur_obj_guiding_rot = cur_obj_guiding_rot.squeeze(0)
                cur_obj_guiding_rot = quat_mul(cur_obj_guiding_rot, cur_obj_rot.unsqueeze(0)).squeeze(0) 

                self.guiding_pose[i, :] = cur_obj_guiding_rot 
                self.upd_guiding_pose_steps[i] += 1

            if self.add_force_obs: 
                self.gym.enable_actor_dof_force_sensors(env_ptr, hand_actor)

            # Object scale: a base scale chosen per env (cycling through the available
            # scales with i), jittered uniformly by +/- 0.025 when randomization is on.
            obj_scale = self.base_obj_scale
            if self.randomize_scale:
                if self.use_multi_objs:
                    obj_scale = self.obj_inst_idx_to_scale_list[obj_index][i % len(self.obj_inst_idx_to_scale_list[obj_index])]
                    # Record the un-jittered base scale of this env.
                    self.envs_scale_list.append(obj_scale + 0)
                    obj_scale = np.random.uniform(obj_scale - 0.025, obj_scale + 0.025)
                else:
                    num_scales = len(self.randomize_scale_list)
                    obj_scale = np.random.uniform(self.randomize_scale_list[i % num_scales] - 0.025, self.randomize_scale_list[i % num_scales] + 0.025)

            self.gym.set_actor_scale(env_ptr, object_handle, obj_scale)
            self._update_priv_buf(env_id=i, name='obj_scale', value=obj_scale) # store the applied (jittered) scale


            # Object centre of mass: each coordinate drawn independently when randomization is on.
            obj_com = [0, 0, 0]
            if self.randomize_com:
                prop = self.gym.get_actor_rigid_body_properties(env_ptr, object_handle)
                assert len(prop) == 1
                obj_com = [np.random.uniform(self.randomize_com_lower, self.randomize_com_upper),
                           np.random.uniform(self.randomize_com_lower, self.randomize_com_upper),
                           np.random.uniform(self.randomize_com_lower, self.randomize_com_upper)]
                prop[0].com.x, prop[0].com.y, prop[0].com.z = obj_com
                self.gym.set_actor_rigid_body_properties(env_ptr, object_handle, prop)
            self._update_priv_buf(env_id=i, name='obj_com', value=obj_com)

            # Friction: one value shared by every hand shape and every object shape of this env.
            obj_friction = 1.0
            if self.randomize_friction:
                rand_friction = np.random.uniform(self.randomize_friction_lower, self.randomize_friction_upper)

                # A preset identified friction coefficient overrides the random draw.
                if self.preset_pd_gains and self.preset_identified_friction_coef is not None:
                    rand_friction = self.preset_identified_friction_coef

                hand_props = self.gym.get_actor_rigid_shape_properties(env_ptr, hand_actor)
                for p in hand_props:
                    p.friction = rand_friction
                self.gym.set_actor_rigid_shape_properties(env_ptr, hand_actor, hand_props)

                object_props = self.gym.get_actor_rigid_shape_properties(env_ptr, object_handle)
                for p in object_props:
                    p.friction = rand_friction
                self.gym.set_actor_rigid_shape_properties(env_ptr, object_handle, object_props)
                obj_friction = rand_friction
            self._update_priv_buf(env_id=i, name='obj_friction', value=obj_friction)


            if self.real_to_sim_auto_tune:
                # Auto-tune mode: randomize the hand's rigid-body masses and inertias per env.
                hand_prop = self.gym.get_actor_rigid_body_properties(env_ptr, hand_actor)
                rnd_masses = []
                rnd_inertia_values = []
                for p in hand_prop:
                    # Bodies heavier than 0.01 get a mass drawn uniformly from [0.5, 1.5] x original;
                    # lighter bodies keep their mass.
                    if p.mass > 0.01:
                        lower_mass = p.mass * 0.5
                        upper_mass = p.mass * 1.5
                        rnd_mass = np.random.uniform(lower_mass, upper_mass)
                        p.mass = rnd_mass
                        rnd_masses.append(rnd_mass)
                    else:
                        rnd_masses.append(p.mass)

                    inertia = p.inertia

                    # Rescale the magnitude of each of the six independent inertia entries by an
                    # independent uniform factor in [0.5, 1.5], keeping the original sign.
                    ixx, ixy, ixz, iyy, iyz, izz = inertia.x.x, inertia.x.y, inertia.x.z, inertia.y.y, inertia.y.z, inertia.z.z
                    i_vals = np.array([ixx, ixy, ixz, iyy, iyz, izz], dtype=np.float32)
                    sign_i_vals = np.sign(i_vals)
                    lower_abs_i_vals = np.abs(i_vals) * 0.5
                    upper_abs_i_vals = np.abs(i_vals) * 1.5
                    rnd_0_1_values = np.random.uniform(0, 1, 6)
                    rnd_i_vals = lower_abs_i_vals + (upper_abs_i_vals - lower_abs_i_vals) * rnd_0_1_values
                    rnd_i_vals = rnd_i_vals * sign_i_vals
                    rnd_inertia_values.append(rnd_i_vals)
                    inertia.x.x, inertia.x.y, inertia.x.z, inertia.y.y, inertia.y.z, inertia.z.z = rnd_i_vals.tolist()
                    # Mirror the off-diagonal entries so the inertia matrix stays symmetric.
                    inertia.y.x, inertia.z.x, inertia.z.y = inertia.x.y, inertia.x.z, inertia.y.z
                    p.inertia = inertia
                self.gym.set_actor_rigid_body_properties(env_ptr, hand_actor, hand_prop)
                rnd_masses = np.array(rnd_masses, dtype=np.float32)
                rnd_inertia_values = np.array(rnd_inertia_values, dtype=np.float32)
                self.envs_hand_actor_rigid_body_mass_list.append(rnd_masses)
                self.envs_hand_actor_rigid_body_inertia_values.append(rnd_inertia_values)


                # Tune hand and object friction: a fresh draw that overwrites the friction set
                # in the randomize_friction block above for both actors.
                # NOTE(review): the 'obj_friction' value already stored via _update_priv_buf is
                # not updated with this new draw -- confirm this is intended.
                rnd_frictions = []
                cur_body_rnd_friction = np.random.uniform(self.randomize_friction_lower, self.randomize_friction_upper)
                hand_shape_prop = self.gym.get_actor_rigid_shape_properties(env_ptr, hand_actor)
                for p in hand_shape_prop:
                    p.friction = cur_body_rnd_friction
                    rnd_frictions.append(cur_body_rnd_friction)
                self.gym.set_actor_rigid_shape_properties(env_ptr, hand_actor, hand_shape_prop)

                object_shape_prop = self.gym.get_actor_rigid_shape_properties(env_ptr, object_handle)
                for p in object_shape_prop:
                    p.friction = cur_body_rnd_friction
                    rnd_frictions.append(cur_body_rnd_friction)
                self.gym.set_actor_rigid_shape_properties(env_ptr, object_handle, object_shape_prop)
                self.envs_hand_actor_rigid_body_friction_list.append(rnd_frictions)

            # Preset identified hand masses, looked up by this env's object type.
            # if self.preset_pd_gains and self.rigid_body_masses is not None:
            if self.preset_pd_gains and self.policy_idx_to_rigid_body_masses[cur_obj_type_idx] is not None:
                hand_prop = self.gym.get_actor_rigid_body_properties(env_ptr, hand_actor)
                rnd_masses = []  # not used in this branch
                for i_rigid_body in range(self.policy_idx_to_rigid_body_masses[cur_obj_type_idx].size(0)):
                    hand_prop[i_rigid_body].mass = self.policy_idx_to_rigid_body_masses[cur_obj_type_idx][i_rigid_body].detach().cpu().item()
                self.gym.set_actor_rigid_body_properties(env_ptr, hand_actor, hand_prop)

            # Preset identified hand inertias (six independent entries per body), same lookup.
            # if self.preset_pd_gains and self.rigid_body_inertias is not None:
            if self.preset_pd_gains and self.policy_idx_to_rigid_body_inertias[cur_obj_type_idx] is not None:
                hand_prop = self.gym.get_actor_rigid_body_properties(env_ptr, hand_actor)
                for i_rigid_body in range(self.policy_idx_to_rigid_body_inertias[cur_obj_type_idx].size(0)):
                    ixx, ixy, ixz, iyy, iyz, izz = self.policy_idx_to_rigid_body_inertias[cur_obj_type_idx][i_rigid_body].detach().cpu().tolist()
                    inertia = hand_prop[i_rigid_body].inertia
                    inertia.x.x, inertia.x.y, inertia.x.z, inertia.y.y, inertia.y.z, inertia.z.z = ixx, ixy, ixz, iyy, iyz, izz
                    inertia.y.x, inertia.z.x, inertia.z.y = ixy, ixz, iyz
                    hand_prop[i_rigid_body].inertia = inertia
                self.gym.set_actor_rigid_body_properties(env_ptr, hand_actor, hand_prop)



            # Object mass: randomized if enabled; either way the first body's mass is recorded.
            if self.randomize_mass:
                prop = self.gym.get_actor_rigid_body_properties(env_ptr, object_handle)
                for p in prop: # randomize mass lower and upper --- 50 g to 51 g
                    p.mass = np.random.uniform(self.randomize_mass_lower, self.randomize_mass_upper)
                self.gym.set_actor_rigid_body_properties(env_ptr, object_handle, prop)
                self._update_priv_buf(env_id=i, name='obj_mass', value=prop[0].mass)
            else:
                prop = self.gym.get_actor_rigid_body_properties(env_ptr, object_handle)
                self._update_priv_buf(env_id=i, name='obj_mass', value=prop[0].mass)

            if self.aggregate_mode > 0:
                self.gym.end_aggregate(env_ptr)
            self.envs.append(env_ptr)

        # Convert the per-env bookkeeping lists into tensors on the simulation device.
        self.envs_obj_indices = to_torch(self.envs_obj_indices, dtype=torch.long, device=self.device) # nn_envs
        self.object_init_state = to_torch(self.object_init_state, device=self.device, dtype=torch.float).view(self.num_envs, 13)
        self.object_rb_handles = to_torch(self.object_rb_handles, dtype=torch.long, device=self.device)
        self.hand_indices = to_torch(self.hand_indices, dtype=torch.long, device=self.device)
        self.object_indices = to_torch(self.object_indices, dtype=torch.long, device=self.device)


        self.envs_policy_idx = to_torch(self.envs_policy_idx, dtype=torch.long, device=self.device)

        for policy_idx in self.policy_idx_to_env_list:
            self.policy_idx_to_env_list[policy_idx] = to_torch(self.policy_idx_to_env_list[policy_idx], dtype=torch.long, device=self.device)

        # NOTE(review): envs_scale_list is only filled when randomize_scale and use_multi_objs
        # are both set; with use_multi_objs alone this becomes an empty tensor.
        if self.use_multi_objs:
            self.envs_scale_list = torch.tensor(self.envs_scale_list, dtype=torch.float, device=self.device)

        if self.add_zrot_penalty:
            self.envs_obj_points = torch.stack(self.envs_obj_points, dim=0) # nn_envs x nn_obj_pts x 3 #

        # if self.hand_tracking:
        #     self.target_hand_indices = to_torch(self.target_hand_indices, dtype=torch.long, device=self.device)

        if self.real_to_sim_auto_tune:
            # Auto-tune records stay as numpy arrays, stacked along a leading env axis.
            self.envs_hand_actor_rigid_body_mass_list = np.stack(self.envs_hand_actor_rigid_body_mass_list, axis=0)
            self.envs_hand_actor_rigid_body_inertia_values = np.stack(self.envs_hand_actor_rigid_body_inertia_values, axis=0)
            self.envs_hand_actor_rigid_body_friction_list = np.stack(self.envs_hand_actor_rigid_body_friction_list, axis=0)

        self.fingertip_handles = to_torch(self.fingertip_handles, dtype=torch.long, device=self.device)

    def adjust_gravity_force_coef(self, env_ids):
        if self.schedule_gravity_force:
            if 0 in list(env_ids):
                if self.schedule_gravity_force_step < self.schedule_gravity_force_warming_up_steps:
                    self.cur_gravity_force = self.gravity_force_min
                elif self.schedule_gravity_force_step < self.schedule_gravity_force_increasing_steps + self.schedule_gravity_force_warming_up_steps:
                    self.cur_gravity_force = self.gravity_force_min + (self.gravity_force_max - self.gravity_force_min) * float(self.schedule_gravity_force_step - self.schedule_gravity_force_warming_up_steps) / float(self.schedule_gravity_force_increasing_steps)
                else:
                    self.cur_gravity_force = self.gravity_force_max
                self.schedule_gravity_force_step += 1
                print(f"Debugging - gravity_force: {self.cur_gravity_force}, schedule_gravity_force_step: {self.schedule_gravity_force_step}")
    

    def adjust_rot_vel_coef(self, env_ids):
        
        if self.schedule_rot_vel_coef:
            if 0 in list(env_ids):
                if self.rot_vel_coef_step < self.schedule_rot_vel_warming_up_steps:
                    self.rot_vel_coef = self.rot_vel_coef_min
                elif self.rot_vel_coef_step < self.schedule_rot_vel_increasing_steps + self.schedule_rot_vel_warming_up_steps:
                    self.rot_vel_coef = self.rot_vel_coef_min + (self.rot_vel_coef_max - self.rot_vel_coef_min) * (self.rot_vel_coef_step - self.schedule_rot_vel_warming_up_steps) / (self.schedule_rot_vel_increasing_steps)
                else:
                    self.rot_vel_coef = self.rot_vel_coef_max
                
                self.rot_vel_coef_step += 1
                print(f"Debugging - rot_vel_coef: {self.rot_vel_coef}, rot_vel_coef_step: {self.rot_vel_coef_step}")

    def adjust_rotp_coef(self, env_ids):
        if self.add_rotp:
            if 0 in list(env_ids):
                if self.rotp_step < self.rotp_warmup_steps:
                    self.cur_rotp_coef = 0.0
                elif self.rotp_step < self.rotp_increasing_steps + self.rotp_warmup_steps:
                    self.cur_rotp_coef = self.rotp_coef * float(self.rotp_step - self.rotp_warmup_steps) / float(self.rotp_increasing_steps)
                else:
                    self.cur_rotp_coef = self.rotp_coef
                self.rotp_step += 1
                print(f"Debugging - rotp_coef: {self.cur_rotp_coef}, rotp_step: {self.rotp_step}")


    def try_save_delta_action_model(self, env_ids):
        if 0 in list(env_ids):
            self.compensator_reset_nn += 1
            if self.compensator_reset_nn % 10 == 0:
                file_path = os.path.join(self.config['env']['output_name'], 'compensator')
                os.makedirs(file_path, exist_ok=True)
                if self.action_compensator_w_full_hand or self.wm_per_joint_compensator_full_hand:
                    if self.train_action_compensator_w_real_wm_multi_compensator:
                        for compensator_idx in self.compensator_idx_to_delta_action_model_full_hand:
                            cur_compensator_dict = self.compensator_idx_to_delta_action_model_full_hand[compensator_idx].state_dict()
                            torch.save(cur_compensator_dict, os.path.join(file_path, f'full_hand_compensator_{compensator_idx}_resetn{self.compensator_reset_nn}.pth'))
                        compensator_weight_dict = self.compensator_weight_mlp.state_dict()
                        torch.save(compensator_weight_dict, os.path.join(file_path, f'compensator_weight_mlp_resetn{self.compensator_reset_nn}.pth'))
                    else:
                        full_hand_wm_state_dict = self.delta_action_model_full_hand.state_dict()
                        torch.save(full_hand_wm_state_dict, os.path.join(file_path, f'full_hand_compensator_resetn{self.compensator_reset_nn}.pth'))
                else:
                    for joint_idx in self.joint_idx_to_delta_action_model:
                        cur_joint_idx_dict = self.joint_idx_to_delta_action_model[joint_idx].state_dict()
                        torch.save(cur_joint_idx_dict, os.path.join(file_path, f'delta_action_model_{joint_idx}_resetn{self.compensator_reset_nn}.pth'))
                

    def reset_idx(self, env_ids):
        """Reset the environments listed in ``env_ids`` to a freshly sampled initial state.

        When ``real_to_sim_auto_tune`` is enabled the work is delegated to
        ``reset_idx_autotune`` and nothing else in this method runs.  Otherwise the
        method, in order:

        1. optionally saves the delta-action models and calls the ``adjust_*`` hooks;
        2. re-samples the PD gains (when ``randomize_pd_gains`` is set);
        3. for every object scale, samples a cached grasp (16 hand DOF values
           followed by a 7-d object root pose) and writes it into the root-state,
           target-root-state and DOF buffers, optionally applying a random wrist
           rotation and a small positional disturbance;
        4. optionally overrides the hand state with replay data;
        5. pushes root states, DOF targets and DOF states to the simulator;
        6. refreshes guidance / tracking / rotation-axis targets;
        7. clears the per-env observation and bookkeeping buffers.

        Args:
            env_ids: indices of the environments to reset.  It is used to index
                torch buffers and supports ``len`` and ``%`` — presumably a 1-D
                long tensor on ``self.device`` (TODO confirm against the caller).
        """

        # Auto-tune mode has its own, much simpler reset path.
        if self.real_to_sim_auto_tune:
            self.reset_idx_autotune(env_ids=env_ids)
            return

        if self.train_action_compensator_w_real_wm:
            self.try_save_delta_action_model(env_ids)
        self.adjust_rot_vel_coef(env_ids)
        self.adjust_gravity_force_coef(env_ids)
        self.adjust_rotp_coef(env_ids)

        # Counts calls to this method (not the number of environments reset).
        self.tot_reset_nn = self.tot_reset_nn + 1

        if self.randomize_pd_gains:
            if self.preset_pd_gains:
                # Preset gains plus zero-centred noise whose width equals the
                # configured [lower, upper] randomisation range.
                self.p_gain[env_ids] =  torch_rand_float(
                    (self.randomize_p_gain_upper - self.randomize_p_gain_lower) * (-0.5), (self.randomize_p_gain_upper - self.randomize_p_gain_lower) * (0.5), (len(env_ids), self.num_dofs), device=self.device).squeeze(1) + self.preset_pgains[env_ids]   # .unsqueeze(0)
                self.d_gain[env_ids] =  torch_rand_float(
                    (self.randomize_d_gain_upper - self.randomize_d_gain_lower) * (-0.5), (self.randomize_d_gain_upper - self.randomize_d_gain_lower) * (0.5), (len(env_ids), self.num_dofs), device=self.device).squeeze(1) + self.preset_dgains[env_ids] 
            else:
                # Gains drawn uniformly from the configured [lower, upper] range.
                self.p_gain[env_ids] = torch_rand_float(
                    self.randomize_p_gain_lower, self.randomize_p_gain_upper, (len(env_ids), self.num_dofs),
                    device=self.device).squeeze(1)
                self.d_gain[env_ids] = torch_rand_float(
                    self.randomize_d_gain_lower, self.randomize_d_gain_upper, (len(env_ids), self.num_dofs),
                    device=self.device).squeeze(1)


        # reset rigid body forces
        self.rb_forces[env_ids, :, :] = 0.0

        # Grasp caches are keyed by object scale, so the envs are processed one
        # scale group at a time.
        num_scales = len(self.randomize_scale_list)
        for n_s in range(num_scales):


            if self.use_multi_objs:
                # Group by the scale recorded for each env.
                s_ids = env_ids[(self.envs_scale_list[env_ids] == self.randomize_scale_list[n_s]).nonzero(as_tuple=False).squeeze(-1) ]
            else:
                # Scales are assigned round-robin by env index.
                s_ids = env_ids[(env_ids % num_scales == n_s).nonzero(as_tuple=False).squeeze(-1)]

            if len(s_ids) == 0:
                continue

            obj_scale = self.randomize_scale_list[n_s]
            scale_key = str(obj_scale)

            # sampled_pose rows: [:16] hand DOF positions, [16:23] object root pose
            # (3 position values followed by a 4-d quaternion).
            if self.seperate_inst_grasp_pose:
                # One grasp cache per object instance: fill the rows instance by instance.
                sampled_pose = torch.zeros((len(s_ids), 16 + 7), device=self.device)
                for i_inst in range(self.nn_object_insts):
                    cur_inst_s_ids = (self.envs_obj_indices[s_ids] == i_inst).nonzero(as_tuple=False).squeeze(-1)
                    if len(cur_inst_s_ids) == 0:
                        continue
                    sampled_pose_idx = np.random.randint(self.saved_grasping_states[i_inst][scale_key].shape[0], size=len(cur_inst_s_ids))
                    sampled_pose[cur_inst_s_ids, :] = self.saved_grasping_states[i_inst][scale_key][sampled_pose_idx].clone()
            else:
                sampled_pose_idx = np.random.randint(self.saved_grasping_states[scale_key].shape[0], size=len(s_ids))
                sampled_pose = self.saved_grasping_states[scale_key][sampled_pose_idx].clone()

            # print(f"sampled_pose: {sampled_pose}")
            # Object root pose from the cache; columns 7:13 (the velocity slots of
            # Isaac Gym's 13-d root state) are zeroed.
            self.root_state_tensor[self.object_indices[s_ids], :7] = sampled_pose[:, 16:]
            self.root_state_tensor[self.object_indices[s_ids], 7:13] = 0

            if self.add_translation:
                # The 0.00 factor makes this offset a no-op as written.
                self.root_state_tensor[self.object_indices[s_ids], 0:3] -= self.trans_dir_buf[s_ids, :3] * 0.00


            #### Set target root state tensor ####
            # Target pose = sampled object pose raised by 0.05 along z.
            self.target_root_state_tensor[self.object_indices[s_ids], :7] = sampled_pose[:, 16:].clone()
            # self.target_root_state_tensor[self.object_indices[s_ids], 2] += 0.01
            self.target_root_state_tensor[self.object_indices[s_ids], 2] += 0.05
            #### Set target root state tensor ####



            ### sampled goal grasping poses ###
            # A second, independent draw from the same cache serves as the goal grasp.
            if self.grasp_to_grasp: 
                if self.seperate_inst_grasp_pose:
                    goal_grasp_pose = torch.zeros((len(s_ids), 16 + 7), device=self.device)
                    for i_inst in range(self.nn_object_insts):
                        cur_inst_s_ids = (self.envs_obj_indices[s_ids] == i_inst).nonzero(as_tuple=False).squeeze(-1)
                        if len(cur_inst_s_ids) == 0:
                            continue
                        sampled_goal_pose_idx = np.random.randint(self.saved_grasping_states[i_inst][scale_key].shape[0], size=len(cur_inst_s_ids))
                        goal_grasp_pose[cur_inst_s_ids, :] = self.saved_grasping_states[i_inst][scale_key][sampled_goal_pose_idx].clone()
                else:
                    sampled_goal_pose_idx = np.random.randint(self.saved_grasping_states[scale_key].shape[0], size=len(s_ids))
                    goal_grasp_pose = self.saved_grasping_states[scale_key][sampled_goal_pose_idx].clone()
                # self.goal_
                self.goal_object_pose[s_ids, :] = goal_grasp_pose[:, 16: ]
                self.goal_hand_pose[s_ids, :] = goal_grasp_pose[:, :16]

                # self.root_state_tensor[self.target_object_indices[s_ids], 3:7] = self.goal_object_pose[s_ids, 3:7].clone()



            # Randomise the wrist orientation: hand and object are rotated together
            # by the same random rotation, with the object position rotated about
            # the hand position.
            if self.omni_wrist_ornt or (self.hand_facing_dir == 'down' and self.grasp_cache_name in ['leap_change_g_dir']):
                rnd_float = torch_rand_float(-1.0, 1.0, (len(s_ids), 3), device=self.device)

                new_rnd_rot = randomize_rotation_rpy(rnd_float[:, 0], rnd_float[:, 1], rnd_float[:, 2], self.x_unit_tensor[s_ids], self.y_unit_tensor[s_ids], self.z_unit_tensor[s_ids])
                # hand_pose_tsr holds [position(3), quaternion] of the hand.
                q_h = self.hand_pose_tsr[ 3:].unsqueeze(0).repeat(len(s_ids), 1).contiguous()
                q_h_new = quat_mul(new_rnd_rot, q_h)
                q_o = sampled_pose[:, 16 + 3: 16 + 7]
                q_o_new = quat_mul(new_rnd_rot, q_o)
                t_o = sampled_pose[:, 16: 16 + 3] # s_len x 3
                t_h = self.hand_pose_tsr[:3].unsqueeze(0).repeat(len(s_ids), 1).contiguous() # s_len x 3
                t_o_new = t_h + quat_apply(new_rnd_rot, t_o - t_h)



                if self.grasp_cache_name == 'leap_down':
                    # This cache discards the random rotation for the poses: the object
                    # keeps its sampled pose and the hand gets a fixed orientation.
                    # (new_rnd_rot is still stored in rnd_rot_tensor below.)
                    t_o_new = sampled_pose[:, 16: 16 + 3]
                    q_o_new = sampled_pose[:, 16 + 3: 16 + 7]

                    hand_rot_quat = gymapi.Quat.from_axis_angle(
                        gymapi.Vec3(0, 1, 0), np.pi / 2) * gymapi.Quat.from_axis_angle(gymapi.Vec3(1, 0, 0), np.pi / 2)
                    hand_rot_quat_tsr = torch.tensor([hand_rot_quat.x, hand_rot_quat.y, hand_rot_quat.z, hand_rot_quat.w], ).float().unsqueeze(0).repeat(len(s_ids), 1).to(self.device)
                    q_h_new = hand_rot_quat_tsr

                    self.reset_z_threshold_tensor[s_ids] = 0.375 # t_o_new[:, 2] + delta_reset_z #
                else:
                    # Keep the same margin between the object's z and the reset
                    # threshold as before the rotation (column 18 is the sampled object z).
                    delta_reset_z = self.reset_z_threshold - sampled_pose[:, 18]
                    self.reset_z_threshold_tensor[s_ids] = t_o_new[:, 2] + delta_reset_z


                if self.omni_wrist_ornt and self.omni_wrist_horizontal_ornt_only and self.specified_wrist_ornt == 'palm_up':
                    t_o_new[..., 2] += 0.02

                self.root_state_tensor[self.object_indices[s_ids], 3:7] = q_o_new
                self.root_state_tensor[self.object_indices[s_ids], :3] = t_o_new
                self.root_state_tensor[self.hand_indices[s_ids], 3:7] = q_h_new
                self.root_state_tensor[self.hand_indices[s_ids], :3] = t_h
                self.rnd_rot_tensor[s_ids, :] = new_rnd_rot

                # Set target root state tensor #
                # Same poses for the target buffer, with the object again raised by 0.05 in z.
                self.target_root_state_tensor[self.object_indices[s_ids], 3:7] = q_o_new.clone()
                self.target_root_state_tensor[self.object_indices[s_ids], :3] = t_o_new.clone()
                # self.target_root_state_tensor[self.object_indices[s_ids], 2] += 0.01 # offset the z-axis by 0.01 
                self.target_root_state_tensor[self.object_indices[s_ids], 2] += 0.05
                self.target_root_state_tensor[self.hand_indices[s_ids], 3:7] = q_h_new.clone()
                self.target_root_state_tensor[self.hand_indices[s_ids], :3] = t_h.clone()

            if self.add_disturbances_to_init_state:

                ###### v5 disturbances --- testing for in-hand translation ######
                # Per-axis magnitude range of the positional disturbance (at most 1e-6 here).
                disturbing_lower_limits = [0.0, 0.0, 0.0]
                disturbing_upper_limits = [0.000001, 0.000001, 0.000001] 
                ###### v5 disturbances --- testing for in-hand translation ######


                rnd_disturbing_obj_delta_pos_x = torch_rand_float(disturbing_lower_limits[0], disturbing_upper_limits[0], (len(s_ids), 1), device=self.device)
                rnd_disturbing_obj_delta_pos_y = torch_rand_float(disturbing_lower_limits[1], disturbing_upper_limits[1], (len(s_ids), 1), device=self.device)
                rnd_disturbing_obj_delta_pos_z = torch_rand_float(disturbing_lower_limits[2], disturbing_upper_limits[2], (len(s_ids), 1), device=self.device)
                rnd_disturbing_obj_delta_pos = torch.cat((rnd_disturbing_obj_delta_pos_x, rnd_disturbing_obj_delta_pos_y, rnd_disturbing_obj_delta_pos_z), dim=1)
                ##### xyz-specific lower limit and upper limit #####

                # Random sign for x / y; the z entry is forced negative, so the z
                # disturbance is always flipped to point downwards.
                rnd_select_sign = torch_rand_float(-1.0, 1.0, (len(s_ids), 3), device=self.device)

                rnd_select_sign[..., 2] = -1.0 

                rnd_disturbing_obj_delta_pos[rnd_select_sign < 0] *= (-1.0)
                self.root_state_tensor[self.object_indices[s_ids], :3] += rnd_disturbing_obj_delta_pos


            # Hand joints start at the sampled grasp with zero velocity; both target
            # buffers start at the same configuration.
            pos = sampled_pose[:, :16]
            self.allegro_hand_dof_pos[s_ids, :] = pos
            self.allegro_hand_dof_vel[s_ids, :] = 0
            self.prev_targets[s_ids, :self.num_allegro_hand_dofs] = pos
            self.cur_targets[s_ids, :self.num_allegro_hand_dofs] = pos
            self.init_pose_buf[s_ids, :] = pos.clone()



        if self.train_action_compensator:
            # Assign every reset env a random replay trajectory and start the hand
            # at that trajectory's first frame (overrides the grasp sampled above).
            rnd_selected_env_idxes = torch.randint(0, self.real_replay_qtars.size(0), (len(env_ids), ), dtype=torch.long, device=self.device)
            self.envs_replay_qpos[env_ids, :] = batched_index_select(self.real_replay_qpos, rnd_selected_env_idxes, dim=0)  
            self.envs_replay_qtars[env_ids, :] = batched_index_select(self.real_replay_qtars, rnd_selected_env_idxes, dim=0)
            self.allegro_hand_dof_pos[env_ids, :] = self.envs_replay_qpos[env_ids, 0, :].clone()
            self.prev_targets[env_ids, :self.num_allegro_hand_dofs] = self.envs_replay_qtars[env_ids, 0, :].clone()
            self.cur_targets[env_ids, :self.num_allegro_hand_dofs] = self.envs_replay_qtars[env_ids, 0, :].clone()
            envs_init_obj_pose = batched_index_select(self.real_replay_init_obj_states, rnd_selected_env_idxes, dim=0) # nn_reset_envs x 7
            if self.action_compensator_w_obj: 
                self.root_state_tensor[self.object_indices[env_ids], :7] = envs_init_obj_pose.clone()

        if self.openloop_replay:
            # Open-loop replay: start from the first frame of each env's source trajectory.
            self.allegro_hand_dof_pos[env_ids,:] = self.openloop_replay_src_states[env_ids, 0, :].clone()
            self.prev_targets[env_ids, :self.num_allegro_hand_dofs] = self.openloop_replay_src_states[env_ids, 0, :].clone()
            self.cur_targets[env_ids, :self.num_allegro_hand_dofs] = self.openloop_replay_src_states[env_ids, 0, :].clone()

        # Push the edited buffers to the simulator for the reset actors only.
        object_indices = torch.unique(self.object_indices[env_ids]).to(torch.int32)
        self.gym.set_actor_root_state_tensor_indexed(self.sim, gymtorch.unwrap_tensor(self.root_state_tensor), gymtorch.unwrap_tensor(object_indices), len(object_indices))
        hand_indices = self.hand_indices[env_ids].to(torch.int32)
        if not self.torque_control:
            self.gym.set_dof_position_target_tensor_indexed(self.sim, gymtorch.unwrap_tensor(self.prev_targets), gymtorch.unwrap_tensor(hand_indices), len(env_ids))
        self.gym.set_dof_state_tensor_indexed(self.sim, gymtorch.unwrap_tensor(self.dof_state), gymtorch.unwrap_tensor(hand_indices), len(env_ids))

        # The hand root pose only changes when the wrist orientation was randomised above.
        # NOTE(review): this is a second indexed root-state write in the same reset
        # (objects were written above) — confirm the simulator applies both calls
        # within one step rather than only the last one.
        if self.omni_wrist_ornt  or (self.hand_facing_dir == 'down' and self.grasp_cache_name in ['leap_change_g_dir']):
            self.gym.set_actor_root_state_tensor_indexed(self.sim, gymtorch.unwrap_tensor(self.root_state_tensor), gymtorch.unwrap_tensor(hand_indices), len(env_ids))

        if self.add_aux_pose_guidance:
            # Guidance pose = object orientation rotated by a fixed angle about the
            # rotation axis.  Computed for all envs, stored only for the reset ones.
            if self.omni_wrist_ornt:
                rotated_rot_axis_buf = quat_apply(self.rnd_rot_tensor, self.rot_axis_buf)
                cur_obj_guiding_rot = quat_from_angle_axis(  self.guiding_delta_rot_radian * torch.ones((self.num_envs,), dtype=torch.float32).to(self.device) , rotated_rot_axis_buf)
            else:
                cur_obj_guiding_rot = quat_from_angle_axis(  self.guiding_delta_rot_radian * torch.ones((self.num_envs,), dtype=torch.float32).to(self.device) , self.rot_axis_buf)
            nex_guiding_pose = quat_mul(cur_obj_guiding_rot, self.object_rot) 
            self.guiding_pose[env_ids] = nex_guiding_pose[env_ids]
            self.upd_guiding_pose_steps[env_ids] = 1


        # randomize hand tracking targets
        if self.hand_tracking:
            # update the hand tracking targets; set the period count to zero #
            # Target = current joint positions plus uniform noise in [-1, 1],
            # clamped to the joint limits.
            hand_pose_rand_floats = torch_rand_float(-1.0, 1.0, (len(env_ids), self.num_allegro_hand_dofs), device=self.device)
            target_hand_pose = self.allegro_hand_dof_pos[env_ids, :] + hand_pose_rand_floats # * 0.25
            target_hand_pose = tensor_clamp(target_hand_pose, self.allegro_hand_dof_lower_limits, self.allegro_hand_dof_upper_limits)
            self.hand_tracking_targets[env_ids, :self.num_allegro_hand_dofs] = target_hand_pose
            self.hand_tracking_period_count[env_ids] = 0


        if self.randomize_rot_dir:
            # Start from the +z axis and rotate it randomly to obtain a new rotation axis.
            ori_rot_axis_buf = torch.zeros((len(env_ids), 3), device=self.device, dtype=torch.float32)
            ori_rot_axis_buf[..., 2] = 1.0
            rnd_rot_float = torch_rand_float(-1.0, 1.0, (len(env_ids), 3), device=self.device)
            rnd_rot = randomize_rotation_rpy(rnd_rot_float[:, 0], rnd_rot_float[:, 1], rnd_rot_float[:, 2], self.x_unit_tensor[env_ids], self.y_unit_tensor[env_ids], self.z_unit_tensor[env_ids])
            rnd_rot_axis_buf = quat_apply(rnd_rot, ori_rot_axis_buf)
            self.rot_axis_buf[env_ids, :] = rnd_rot_axis_buf.clone()

            if self.use_preset_rot_dir:
                # A preset direction replaces the random axis written just above.
                self.rot_axis_buf[env_ids, :] = self.preset_rot_dir.clone().unsqueeze(0).repeat(len(env_ids), 1).contiguous()

            # Both goal-conditioned branches below only run when randomize_rot_dir is set.
            if self.evaluate and self.evaluate_goal_conditioned: # evaluate goal conditioned #
                # goal conditioned #
                rnd_float = torch_rand_float(-1.0, 1.0, (len(env_ids), 3), device=self.device)

                # The random draw is overwritten with fixed inputs (0, 0, 0.35), so
                # evaluation always uses the same additional rotation.
                rnd_float[..., 0] = 0.0
                rnd_float[..., 1] = 0.0
                rnd_float[..., 2] = 0.35


                additional_rnd_rot = randomize_rotation_rpy(rnd_float[:, 0], rnd_float[:, 1], rnd_float[:, 2], self.x_unit_tensor[env_ids], self.y_unit_tensor[env_ids], self.z_unit_tensor[env_ids])
                cur_obj_rot_quat = self.root_state_tensor[self.object_indices[env_ids], 3: 7] # get the obj rot quat #
                rnd_target_obj_ornt = quat_mul(additional_rnd_rot, cur_obj_rot_quat) # obj rnd rot; additional rand rot #
                self.target_obj_pose_buf[env_ids, :] = rnd_target_obj_ornt.clone() # target obj ornt # # 
                # The rotation axis is taken as the normalised Euler-angle vector of
                # the additional rotation (norm clamped to avoid division by zero).
                ex, ey, ez = get_euler_xyz(additional_rnd_rot)
                additional_rnd_rot_euler = torch.stack([ ex, ey, ez ], dim=-1)
                rot_axis_from_rnd_rot = additional_rnd_rot_euler / torch.clamp(torch.norm(additional_rnd_rot_euler, p=2, dim=-1, keepdim=True), min=1e-6) # rot axis from rnd rot #
                # print(f"rot_axis_from_rnd_rot: {rot_axis_from_rnd_rot[0]}")
                self.rot_axis_buf[env_ids, :] = rot_axis_from_rnd_rot[:].clone() # 

            if self.train_goal_conditioned:
                # Same construction as the evaluation branch, but with a truly random rotation.
                rnd_float  = torch_rand_float(-1.0, 1.0, (len(env_ids), 3), device=self.device)
                additional_rnd_rot = randomize_rotation_rpy(rnd_float[:, 0], rnd_float[:, 1], rnd_float[:, 2], self.x_unit_tensor[env_ids], self.y_unit_tensor[env_ids], self.z_unit_tensor[env_ids])
                cur_obj_rot_quat = self.root_state_tensor[self.object_indices[env_ids], 3: 7] # get the obj rot quat #
                rnd_target_obj_ornt = quat_mul(additional_rnd_rot, cur_obj_rot_quat) # obj rnd rot; additional rand rot #
                self.target_obj_pose_buf[env_ids, :] = rnd_target_obj_ornt.clone() # target obj ornt # # 
                ex, ey, ez = get_euler_xyz(additional_rnd_rot)
                additional_rnd_rot_euler = torch.stack([ ex, ey, ez ], dim=-1)
                rot_axis_from_rnd_rot = additional_rnd_rot_euler / torch.clamp(torch.norm(additional_rnd_rot_euler, p=2, dim=-1, keepdim=True), min=1e-6)
                self.rot_axis_buf[env_ids, :] = rot_axis_from_rnd_rot[:].clone()

        # Remember the object pose the episode started from.
        self.init_root_state_tensor[self.object_indices[env_ids], :7] = self.root_state_tensor[self.object_indices[env_ids], :7].clone()

        # Fill all three lagged-history buffers with the initial joint positions
        # (the target histories are seeded with positions as well).
        self.obs_buf_lag_history_qpos[env_ids, :, :] = self.allegro_hand_dof_pos[env_ids].unsqueeze(1).repeat(1, self.lagging_obs_length, 1).contiguous()
        self.obs_buf_lag_history_qtars[env_ids, :, :] = self.allegro_hand_dof_pos[env_ids].unsqueeze(1).repeat(1, self.lagging_obs_length, 1).contiguous()
        self.obs_buf_lag_history_compensated_qtars[env_ids, :, :] = self.allegro_hand_dof_pos[env_ids].unsqueeze(1).repeat(1, self.lagging_obs_length, 1).contiguous()

        if self.random_start:
            # 375
            # One-shot: the first reset staggers episode progress across envs, then
            # the flag is cleared so later resets start at step 0.
            self.progress_buf[env_ids] = torch.randint(0, 375, (len(env_ids), ), device=self.device).long()
            self.random_start = False
        else:
            self.progress_buf[env_ids] = 0
        # Clear per-env observation / bookkeeping buffers and flag the envs as just reset.
        self.obs_buf[env_ids] = 0
        self.rb_forces[env_ids] = 0
        self.priv_info_buf[env_ids, 0:3] = 0
        self.proprio_hist_buf[env_ids] = 0
        self.at_reset_buf[env_ids] = 1


    def reset_idx_autotune(self, env_ids):
        """Reset ``env_ids`` to the first frame of the current auto-tune trajectory.

        While reference trajectories remain (``testing_traj_idx`` is below the
        number of entries in ``auto_tune_states``), the hand joints of every env
        in ``env_ids`` are set to the first state of trajectory
        ``testing_traj_idx`` with zero velocity, both target buffers are set to
        the same configuration, and the DOF state is pushed to the simulator.
        With ``real_to_sim_auto_tune_w_obj`` the object root pose is also reset
        from ``auto_tune_init_obj_poses`` with the remaining root-state columns
        zeroed.  ``testing_traj_ts`` is reset to 0.

        Once every trajectory has been consumed, the recorded results are
        written to ``./cache/auto_tune_states_mujoco_to_gym_cuboidthin_<i>.npy``
        in chunks of up to 100 envs and the process exits with status 0.
        ``testing_traj_idx`` is not advanced here — presumably the stepping code
        does that (TODO confirm).

        Args:
            env_ids: indices of the environments to reset; must support
                ``.size(0)`` and ``len`` (presumably a 1-D long tensor).

        Raises:
            SystemExit: after the results have been dumped.
        """
        self.rb_forces[env_ids, :, :] = 0.0
        # self.testing_traj_idx += 1
        if self.testing_traj_idx >= self.auto_tune_states.size(0):

            nn_env_chunks = 100
            # Step through the envs chunk by chunk.  Unlike a floor division on
            # the chunk count, this also emits the trailing partial chunk, so
            # nothing is lost when num_envs is below or not a multiple of 100.
            for i_chunk, cur_st_env in enumerate(range(0, self.num_envs, nn_env_chunks)):
                cur_ed_env = cur_st_env + nn_env_chunks  # slicing clamps at num_envs

                cur_chunk_sv_dict = {
                    'tested_envs_states': self.tested_envs_states[cur_st_env: cur_ed_env, :].detach().cpu().numpy(),
                    'tested_envs_obj_states': self.tested_envs_obj_states[cur_st_env: cur_ed_env, :].detach().cpu().numpy(),
                }

                if (not self.preset_pd_gains):
                    # Randomised-parameter search: store the parameters each env
                    # was simulated with next to its recorded states.
                    cur_chunk_sv_dict.update({
                        'pgains': self.envs_dofs_rnd_pgains[cur_st_env: cur_ed_env, :].detach().cpu().numpy(),
                        'dgains': self.envs_dofs_rnd_dgains[cur_st_env: cur_ed_env, :].detach().cpu().numpy(),
                        'hand_rigid_body_masses': self.envs_hand_actor_rigid_body_mass_list[cur_st_env: cur_ed_env, :],
                        'hand_rigid_body_inertia_values': self.envs_hand_actor_rigid_body_inertia_values[cur_st_env: cur_ed_env, :],
                        'hand_rigid_body_friction': self.envs_hand_actor_rigid_body_friction_list[cur_st_env: cur_ed_env, :],
                        # 'minn_diff': cur_minn_diff_states
                    })
                else:
                    # Preset gains: store the best (lowest) squared tracking error
                    # in the chunk.  The two means assume the summed difference is
                    # nn_envs x nn_seqs x nn_ts (per the original comment).
                    diff_cur_seq_states_w_real_states = torch.sum(
                        (self.tested_envs_states[cur_st_env: cur_ed_env, :] - self.auto_tune_states.unsqueeze(0)) ** 2, dim=-1
                    )
                    diff_cur_seq_states_w_real_states = diff_cur_seq_states_w_real_states.mean(dim=2).mean(dim=1)
                    # cur_minn_diff_states_idx = np.argmin(diff_cur_seq_states_w_real_states)
                    cur_chunk_sv_dict['minn_diff'] = diff_cur_seq_states_w_real_states.min().item()

                ####### chunk sv dict fn #######
                cur_chunk_sv_dict_fn = f"./cache/auto_tune_states_mujoco_to_gym_cuboidthin_{i_chunk}.npy"
                # np.save does not create missing directories.
                os.makedirs(os.path.dirname(cur_chunk_sv_dict_fn), exist_ok=True)
                np.save(cur_chunk_sv_dict_fn, cur_chunk_sv_dict)
                print(f"Debugging - chunk {i_chunk} cur_chunk_sv_dict_fn: {cur_chunk_sv_dict_fn}")

            # sys.exit instead of the `site`-provided exit(), which is meant for
            # the interactive shell and is not guaranteed to exist.
            sys.exit(0)

        # Broadcast the first state of the current reference trajectory to every reset env.
        cur_testing_traj_st_states = self.auto_tune_states[self.testing_traj_idx, 0, :]
        envs_st_states = cur_testing_traj_st_states.unsqueeze(0).repeat(env_ids.size(0), 1).contiguous() # nn_envs x 16
        self.allegro_hand_dof_pos[env_ids, :] = envs_st_states[:, :self.num_allegro_hand_dofs]
        self.allegro_hand_dof_vel[env_ids, :] = 0.0

        self.prev_targets[env_ids, :self.num_allegro_hand_dofs] = envs_st_states[:, :self.num_allegro_hand_dofs]
        self.cur_targets[env_ids, :self.num_allegro_hand_dofs] = envs_st_states[:, :self.num_allegro_hand_dofs]

        hand_indices = self.hand_indices[env_ids].to(torch.int32)
        self.gym.set_dof_state_tensor_indexed(self.sim, gymtorch.unwrap_tensor(self.dof_state), gymtorch.unwrap_tensor(hand_indices), len(env_ids))

        if self.real_to_sim_auto_tune_w_obj:
            print(f"Reseting with object iniit state")
            # Object starts at the recorded initial pose; all remaining root-state
            # columns (7 onwards) are zeroed.
            cur_testing_traj_st_obj_states = self.auto_tune_init_obj_poses[self.testing_traj_idx, :]
            self.root_state_tensor[self.object_indices[env_ids], :7] = cur_testing_traj_st_obj_states.contiguous().unsqueeze(0).repeat(env_ids.size(0), 1).contiguous()
            self.root_state_tensor[self.object_indices[env_ids], 7:] = 0.0
            object_indices = torch.unique(self.object_indices[env_ids]).to(torch.int32)
            self.gym.set_actor_root_state_tensor_indexed(self.sim, gymtorch.unwrap_tensor(self.root_state_tensor), gymtorch.unwrap_tensor(object_indices), len(object_indices))

        self.testing_traj_ts = 0
    
    
    def train_multi_delta_action_model(self, ):
        """Run one gradient step for the mixture of full-hand action compensators.

        Every model in ``compensator_idx_to_delta_action_model_full_hand``
        predicts a delta action (clamped to [-1, 1]) from the normalised recent
        joint-position / joint-target history.  ``compensator_weight_mlp``
        yields softmax weights that blend these predictions; the blend is
        scaled by ``delta_action_scale`` and added to ``cur_targets``.  The
        compensated target is fed to the per-joint world models in
        ``compensator_idx_to_joint_idx_to_wm`` (run in eval mode and blended
        with softmax weights from ``wm_weight_mlp``), and the squared error
        between the blended prediction and ``allegro_hand_dof_pos`` is
        minimised for every joint in ``wm_pred_joint_idxes``.

        Side effects:
            * ``cur_fullhand_compensated_target`` is set to a detached copy of
              the compensated target.
            * ``wm_weight_mlp_optimizer``, ``compensator_weight_mlp_optimizer``
              and every compensator optimizer are zeroed and stepped.
            * Per-joint diagnostics are printed when ``progress_buf[0]`` is a
              multiple of 8.

        Returns:
            float: the per-joint MSE losses summed over ``wm_pred_joint_idxes``.
        """
        dof_lower = self.allegro_hand_dof_lower_limits
        dof_upper = self.allegro_hand_dof_upper_limits

        # ---- compensator inputs: normalised qpos / target histories ----
        cur_qpos_buf = self.obs_buf_lag_history_qpos[:, - self.compensator_history_length:, ].clone() # nn_envs x nn_history_length x 16
        # cur_qpos_buf = (torch.rand(cur_qpos_buf.shape).to(self.device) * 2.0 - 1.0) * self.joint_noise_scale + cur_qpos_buf
        cur_qpos_buf = unscale(cur_qpos_buf, dof_lower, dof_upper)
        cur_target = self.cur_targets.clone()

        if self.compensator_history_length == 1:
            cur_qtars_buf = cur_target
        else:
            # Past targets from the history buffer followed by the current target.
            cur_qtars_buf = self.obs_buf_lag_history_qtars[:, - self.compensator_history_length + 1: , ].clone()
            cur_qtars_buf = torch.cat([cur_qtars_buf, cur_target.unsqueeze(1)], dim=1)
        cur_qtars_buf = unscale(cur_qtars_buf, dof_lower, dof_upper)
        # Flatten the history dimension into one feature vector per env.
        cur_qpos_buf = cur_qpos_buf.contiguous().view(cur_qpos_buf.size(0), -1).contiguous()
        cur_qtars_buf = cur_qtars_buf.contiguous().view(cur_qtars_buf.size(0), -1).contiguous()

        # ---- blend the compensators' delta actions ----
        compensator_selector_input = torch.cat(
            [ cur_qpos_buf, cur_qtars_buf ], dim=-1 # nn_envs x (history latent features)
        )
        compensator_selector_output = self.compensator_weight_mlp(compensator_selector_input)
        compensator_selector_output = torch.softmax(compensator_selector_output, dim=-1) # per-env weight of each compensator
        delta_action_input_dict = {
            'state': cur_qpos_buf, 'action': cur_qtars_buf
        }

        tot_out_delta_action = []

        for compensator_idx in self.compensator_idx_to_delta_action_model_full_hand:
            self.compensator_idx_to_delta_action_model_full_hand[compensator_idx].train()
            cur_out_delta_action = self.compensator_idx_to_delta_action_model_full_hand[compensator_idx](delta_action_input_dict)
            cur_out_delta_action = torch.clamp(cur_out_delta_action, -1.0, 1.0)
            tot_out_delta_action.append(cur_out_delta_action)
        tot_out_delta_action = torch.stack(tot_out_delta_action, dim=1) # nn_envs x (nn_compensators) x nn_hand_dof
        out_delta_action = tot_out_delta_action * compensator_selector_output.unsqueeze(-1) # weighted delta actions
        out_delta_action = out_delta_action.sum(dim=1) # nn_envs x nn_hand_dof

        cur_wm_hist_qpos = self.obs_buf_lag_history_qpos[:, - self.full_hand_wm_history_length: , ].clone()
        # Live (graph-attached) compensated target: this is what the loss must be
        # back-propagated through to train the compensators and their selector.
        cur_compensated_target = self.cur_targets + self.delta_action_scale * out_delta_action

        cur_wm_hist_qpos_flatten = cur_wm_hist_qpos.contiguous().view(cur_wm_hist_qpos.size(0), -1).contiguous()

        # Detached copy stored on the instance (unchanged from before) and used
        # as the world-model selector's input.
        self.cur_fullhand_compensated_target = cur_compensated_target.detach().clone()


        ######## world-model selector: weights over the world models ########
        if self.wm_history_length == 1:
            hist_qtars = self.cur_fullhand_compensated_target.clone()
        else:
            hist_qtars = torch.cat(
                [ self.obs_buf_lag_history_qtars[:, -self.wm_history_length + 1: ], self.cur_fullhand_compensated_target.unsqueeze(1) ], dim=1
            ) # hist_qtars -- nn_envs x wm_history_length x nn_hand_dof
        wm_hist_qtars_flatten = hist_qtars.contiguous().view(hist_qtars.size(0), -1).contiguous()
        concat_hist_qpos_w_hist_qtars = torch.cat( [ cur_wm_hist_qpos_flatten, wm_hist_qtars_flatten  ], dim=-1 )
        wm_selector_output = self.wm_weight_mlp(concat_hist_qpos_w_hist_qtars) 
        wm_selector_output = torch.softmax(wm_selector_output, dim=-1) # weights of the different world models

        tot_loss_diff_backward = []

        # progress_buf does not change inside this method, so the logging gate
        # (which forces a device sync via .item()) is evaluated once.
        log_this_step = self.progress_buf[0].item() % 8 == 0

        for i_j, joint_idx in enumerate(self.wm_pred_joint_idxes):
            cur_joint_tot_wm_output = []
            jt_lower = dof_lower[[joint_idx]]
            jt_upper = dof_upper[[joint_idx]]

            ###### construct wm input ######
            cur_jt_wm_hist_qpos = self.obs_buf_lag_history_qpos[:, - self.wm_history_length:, joint_idx].clone()
            # Use the live tensor here.  Reading the detached instance attribute
            # instead would cut the graph, leaving the compensators and
            # compensator_weight_mlp without any gradient even though their
            # optimizers are stepped below.
            cur_jt_compensated_target = cur_compensated_target[:, joint_idx].unsqueeze(1)
            if self.wm_history_length == 1:
                cur_jt_wm_hist_qtars = cur_jt_compensated_target
            else:
                cur_jt_wm_hist_qtars = torch.cat(
                    [ self.obs_buf_lag_history_qtars[:, - self.wm_history_length + 1: , joint_idx],  cur_jt_compensated_target], dim=1
                )
            cur_jt_wm_hist_qpos = unscale(cur_jt_wm_hist_qpos, jt_lower, jt_upper)
            cur_jt_wm_hist_qtars = unscale(cur_jt_wm_hist_qtars, jt_lower, jt_upper)

            cur_jt_wm_hist_qpos = cur_jt_wm_hist_qpos.contiguous().view(cur_jt_wm_hist_qpos.size(0), -1).contiguous()
            cur_jt_wm_hist_qtars = cur_jt_wm_hist_qtars.contiguous().view(cur_jt_wm_hist_qtars.size(0), -1).contiguous()

            cur_jt_wm_input_dict = {
                'state': cur_jt_wm_hist_qpos,
                'action': cur_jt_wm_hist_qtars
            }

            # The world models are only evaluated here (eval mode, no optimizer
            # of theirs is stepped); their outputs are mapped back to joint range.
            for compensator_idx in self.compensator_idx_to_joint_idx_to_wm:
                self.compensator_idx_to_joint_idx_to_wm[compensator_idx][joint_idx].eval()
                cur_jt_pred_output = self.compensator_idx_to_joint_idx_to_wm[compensator_idx][joint_idx](cur_jt_wm_input_dict)
                cur_jt_pred_output = scale(cur_jt_pred_output, jt_lower, jt_upper)

                cur_joint_tot_wm_output.append(cur_jt_pred_output)
            cur_joint_tot_wm_output = torch.stack(cur_joint_tot_wm_output, dim=1) # nn_envs x nn_compensators x 1 #
            cur_joint_tot_wm_output = cur_joint_tot_wm_output * wm_selector_output.unsqueeze(-1) 
            cur_joint_tot_wm_output = cur_joint_tot_wm_output.sum(dim=1) # nn_envs x 1 #

            loss_diff = torch.mean((cur_joint_tot_wm_output - self.allegro_hand_dof_pos[..., [joint_idx]]) ** 2)

            tot_loss_diff_backward.append(loss_diff)

            if log_this_step:
                # Diagnostics only; no graph needed.
                with torch.no_grad():
                    diff_abs = torch.mean(torch.abs(cur_joint_tot_wm_output - self.allegro_hand_dof_pos[..., [joint_idx]]))
                    diff_abs_pre = torch.abs(self.real_wm_pred_next_state[..., [i_j]] - self.allegro_hand_dof_pos[..., [joint_idx]])
                    diff_abs_pre = diff_abs_pre.mean(dim=0)
                    cur_compensated_target_scale = torch.mean(torch.abs(out_delta_action[:, joint_idx] * self.delta_action_scale))
                print(f"diff_abs: {diff_abs.item()}, diff_abs_pre: {diff_abs_pre.item()}; cur_compensated_target_scale: {cur_compensated_target_scale.item()}")


        tot_loss_diff_backward = sum(tot_loss_diff_backward)

        # ---- optimisation step for both selectors and every compensator ----
        self.wm_weight_mlp_optimizer.zero_grad()
        self.compensator_weight_mlp_optimizer.zero_grad()
        for compensator_idx in self.compensator_idx_to_delta_action_model_full_hand_optimizer:
            self.compensator_idx_to_delta_action_model_full_hand_optimizer[compensator_idx].zero_grad()
        # self.delta_action_model_full_hand_optimizer.zero_grad()

        tot_loss_diff_backward.backward()

        self.wm_weight_mlp_optimizer.step()
        self.compensator_weight_mlp_optimizer.step()
        for compensator_idx in self.compensator_idx_to_delta_action_model_full_hand_optimizer:
            self.compensator_idx_to_delta_action_model_full_hand_optimizer[compensator_idx].step()
        # self.delta_action_model_full_hand_optimizer.step()

        tot_loss_diff = tot_loss_diff_backward.detach().item()

        return tot_loss_diff
    
    
    def train_delta_action_model(self, ):
        """Run one optimisation step of the delta-action (action compensator) model(s).

        Exactly one of three training modes is executed, selected by config flags:

        * ``wm_per_joint_compensator_full_hand``: a single full-hand compensator
          is trained through the frozen per-joint world models.
        * ``action_compensator_w_full_hand``: a single full-hand compensator is
          trained through the frozen full-hand world model.
        * otherwise: one compensator per joint is trained through the matching
          per-joint world model.

        In every mode the loss is the squared error between the world-model
        prediction (fed with the compensated target) and
        ``self.allegro_hand_dof_pos``.  Every 8th step (as read from
        ``self.progress_buf[0]``) diagnostics are printed; every 200th step the
        accumulated diagnostics are saved and plotted.

        Returns:
            float: the summed loss for the two full-hand modes, or the mean of
            the per-joint losses for the per-joint mode.
        """
        # `.item()` forces a device->host synchronisation, so read the step
        # counter once instead of once per joint per diagnostic check.
        cur_step = self.progress_buf[0].item()
        verbose = cur_step % 8 == 0

        if self.wm_per_joint_compensator_full_hand:
            tot_loss_diff = self._train_full_hand_compensator_w_joint_wms(verbose)
        elif self.action_compensator_w_full_hand:
            tot_loss_diff = self._train_full_hand_compensator_w_full_hand_wm(verbose)
        else:
            tot_loss_diff = self._train_per_joint_compensators(verbose)

        if cur_step % 200 == 0:
            self._dump_compensator_logs()

        return tot_loss_diff

    def _full_hand_compensator_delta(self):
        """Build the full-hand compensator input and return its clamped output in [-1, 1]."""
        hist_len = self.compensator_history_length
        lower = self.allegro_hand_dof_lower_limits
        upper = self.allegro_hand_dof_upper_limits

        # presumably nn_envs x nn_history_length x 16 (per the original author's note)
        qpos_hist = self.obs_buf_lag_history_qpos[:, -hist_len:].clone()
        qpos_hist = unscale(qpos_hist, lower, upper)

        cur_target = self.cur_targets.clone()
        if hist_len == 1:
            qtars_hist = cur_target
        else:
            # The newest slot of the action window is the target being issued now.
            qtars_hist = self.obs_buf_lag_history_qtars[:, -hist_len + 1:].clone()
            qtars_hist = torch.cat([qtars_hist, cur_target.unsqueeze(1)], dim=1)
        qtars_hist = unscale(qtars_hist, lower, upper)

        qpos_flat = qpos_hist.contiguous().view(qpos_hist.size(0), -1).contiguous()
        qtars_flat = qtars_hist.contiguous().view(qtars_hist.size(0), -1).contiguous()

        delta_action = self.delta_action_model_full_hand({'state': qpos_flat, 'action': qtars_flat})
        return torch.clamp(delta_action, -1.0, 1.0)

    def _joint_wm_state_hist(self, joint_idx):
        """Return the normalised, flattened joint-position history fed to one joint's world model."""
        qpos_hist = self.obs_buf_lag_history_qpos[:, -self.wm_history_length:, joint_idx].clone()
        qpos_hist = unscale(
            qpos_hist, self.allegro_hand_dof_lower_limits[[joint_idx]], self.allegro_hand_dof_upper_limits[[joint_idx]]
        )
        return qpos_hist.contiguous().view(qpos_hist.size(0), -1).contiguous()

    def _joint_wm_action_hist(self, joint_idx, jt_target):
        """Return the normalised, flattened target history for one joint, ending with ``jt_target``.

        ``jt_target`` is the (column-shaped) target for the current step; it is
        appended after the stored target history so gradients can flow through it.
        """
        if self.wm_history_length == 1:
            qtars_hist = jt_target
        else:
            qtars_hist = torch.cat(
                [self.obs_buf_lag_history_qtars[:, -self.wm_history_length + 1:, joint_idx], jt_target], dim=1
            )
        qtars_hist = unscale(
            qtars_hist, self.allegro_hand_dof_lower_limits[[joint_idx]], self.allegro_hand_dof_upper_limits[[joint_idx]]
        )
        return qtars_hist.contiguous().view(qtars_hist.size(0), -1).contiguous()

    def _train_full_hand_compensator_w_joint_wms(self, verbose):
        """Train the full-hand compensator through the frozen per-joint world models; return the summed loss."""
        self.delta_action_model_full_hand.train()

        lower = self.allegro_hand_dof_lower_limits
        upper = self.allegro_hand_dof_upper_limits

        cur_out_delta_action = self._full_hand_compensator_delta()
        # Keep the graph-attached tensor local: it is what the loss back-propagates through.
        cur_compensated_target = self.cur_targets + self.delta_action_scale * cur_out_delta_action
        # Store a detached copy (as the full-hand-WM mode does).  The attribute is
        # later copied into the compensated-target history buffer, which must not
        # inherit this step's autograd history.
        self.cur_fullhand_compensated_target = cur_compensated_target.detach().clone()

        joint_losses = []
        for i_j, joint_idx in enumerate(self.wm_pred_joint_idxes):
            self.joint_idx_to_wm[joint_idx].eval()

            ###### construct wm input ######
            wm_state = self._joint_wm_state_hist(joint_idx)
            wm_action = self._joint_wm_action_hist(joint_idx, cur_compensated_target[:, joint_idx].unsqueeze(1))

            if self.invdyn_add_nearing_neighbour:
                # Joint indices are handled in groups of four (presumably one
                # finger per group -- TODO confirm); a joint on a group boundary
                # uses itself in place of the missing neighbour.
                bf_joint_idx = joint_idx if joint_idx % 4 == 0 else joint_idx - 1
                af_joint_idx = joint_idx if (joint_idx + 1) % 4 == 0 else joint_idx + 1

                neighbour_states, neighbour_actions = [], []
                for nb_idx in (bf_joint_idx, af_joint_idx):
                    nb_lower, nb_upper = lower[[nb_idx]], upper[[nb_idx]]
                    neighbour_states.append(
                        unscale(self.obs_buf_lag_history_qpos[:, -1, [nb_idx]].clone(), nb_lower, nb_upper)
                    )
                    neighbour_actions.append(unscale(cur_compensated_target[:, [nb_idx]], nb_lower, nb_upper))

                wm_state = torch.cat([wm_state] + neighbour_states, dim=-1)
                wm_action = torch.cat([wm_action] + neighbour_actions, dim=-1)

            cur_jt_pred_output = self.joint_idx_to_wm[joint_idx]({'state': wm_state, 'action': wm_action})
            # The world model predicts in normalised space; map back to joint limits.
            cur_jt_pred_output = scale(cur_jt_pred_output, lower[[joint_idx]], upper[[joint_idx]])

            actual_jt_pos = self.allegro_hand_dof_pos[..., [joint_idx]]
            joint_losses.append(torch.mean((cur_jt_pred_output - actual_jt_pos) ** 2))

            if verbose:
                # NOTE: despite its name, `diff_abs` is the mean *squared* error here.
                diff_abs = torch.mean(torch.abs(cur_jt_pred_output - actual_jt_pos) ** 2)
                diff_abs_pre = torch.abs(self.real_wm_pred_next_state[..., [i_j]] - actual_jt_pos).mean(dim=0)
                cur_compensated_target_scale = torch.mean(
                    torch.abs(cur_out_delta_action[:, joint_idx] * self.delta_action_scale)
                )
                print(f"diff_abs: {diff_abs.item()}, diff_abs_pre: {diff_abs_pre.item()}; cur_compensated_target_scale: {cur_compensated_target_scale.item()}")

                ##### record the wm prediction error and compensation magnitude #####
                if joint_idx not in self.joint_idx_to_wm_pred_delta_abs:
                    self.joint_idx_to_wm_pred_delta_abs[joint_idx] = []
                    self.joint_idx_to_delta_action[joint_idx] = []
                self.joint_idx_to_wm_pred_delta_abs[joint_idx].append(diff_abs.item())
                self.joint_idx_to_delta_action[joint_idx].append(cur_compensated_target_scale.item())

        tot_loss_diff_backward = sum(joint_losses)

        self.delta_action_model_full_hand_optimizer.zero_grad()
        tot_loss_diff_backward.backward()
        self.delta_action_model_full_hand_optimizer.step()

        return tot_loss_diff_backward.detach().item()

    def _train_full_hand_compensator_w_full_hand_wm(self, verbose):
        """Train the full-hand compensator through the frozen full-hand world model; return the loss."""
        self.full_hand_wm.eval()
        self.delta_action_model_full_hand.train()

        cur_out_delta_action = self._full_hand_compensator_delta()

        wm_hist_len = self.full_hand_wm_history_length
        cur_wm_hist_qpos = self.obs_buf_lag_history_qpos[:, -wm_hist_len:].clone()
        cur_compensated_target = self.cur_targets + self.delta_action_scale * cur_out_delta_action

        self.cur_fullhand_compensated_target = cur_compensated_target.detach().clone()

        cur_wm_hist_qtars = self.obs_buf_lag_history_compensated_qtars[:, -wm_hist_len:].clone()
        cur_wm_hist_qpos = cur_wm_hist_qpos.contiguous().view(cur_wm_hist_qpos.size(0), -1).contiguous()
        cur_wm_hist_qtars = cur_wm_hist_qtars.contiguous().view(cur_wm_hist_qtars.size(0), -1).contiguous()
        cur_wm_state = torch.cat([cur_wm_hist_qpos, cur_wm_hist_qtars], dim=-1)

        # Predict the next state with the full-hand world model.  Unlike the
        # per-joint modes, inputs and output are not rescaled by joint limits here.
        cur_pred_nex_state = self.full_hand_wm(cur_wm_state, cur_compensated_target, history_extrin=None)

        loss_diff = torch.sum((cur_pred_nex_state - self.allegro_hand_dof_pos) ** 2, dim=-1).mean()

        if verbose:
            diff_abs = torch.mean(torch.abs(cur_pred_nex_state - self.allegro_hand_dof_pos), dim=0)
            compenated_target_scale = torch.mean(torch.abs(self.delta_action_scale * cur_out_delta_action), dim=0)
            print(f"diff_abs: {diff_abs.detach().cpu().tolist()}")
            print(f"compenstaed_target_scale: {compenated_target_scale.detach().cpu().tolist()}")

        self.delta_action_model_full_hand_optimizer.zero_grad()
        loss_diff.backward()
        self.delta_action_model_full_hand_optimizer.step()

        return loss_diff.detach().item()

    def _train_per_joint_compensators(self, verbose):
        """Train one compensator per joint through its frozen world model; return the mean per-joint loss."""
        lower = self.allegro_hand_dof_lower_limits
        upper = self.allegro_hand_dof_upper_limits
        hist_len = self.compensator_history_length

        joint_losses = []         # graph-attached losses, summed for the backward pass
        joint_loss_values = []    # python floats, averaged for the return value
        tot_pred_joint_state = []

        for i_j, joint_idx in enumerate(self.wm_pred_joint_idxes):
            self.joint_idx_to_delta_action_model[joint_idx].train()
            self.joint_idx_to_wm[joint_idx].eval()

            jt_lower, jt_upper = lower[[joint_idx]], upper[[joint_idx]]

            ###### construct compensator input (noised joint history + target history) ######
            cur_jt_qpos_buf = self.obs_buf_lag_history_qpos[:, -hist_len:, joint_idx].clone()
            cur_jt_qpos_buf = (torch.rand(cur_jt_qpos_buf.shape).to(self.device) * 2.0 - 1.0) * self.joint_noise_scale + cur_jt_qpos_buf
            cur_jt_qpos_buf = unscale(cur_jt_qpos_buf, jt_lower, jt_upper)

            cur_jt_target = self.cur_targets[:, joint_idx]
            if hist_len == 1:
                cur_jt_qtars_buf = cur_jt_target
            else:
                cur_jt_qtars_buf = self.obs_buf_lag_history_qtars[:, -hist_len + 1:, joint_idx].clone()
                cur_jt_qtars_buf = torch.cat([cur_jt_qtars_buf, cur_jt_target.unsqueeze(-1)], dim=1)
            cur_jt_qtars_buf = unscale(cur_jt_qtars_buf, jt_lower, jt_upper)

            cur_jt_qpos_buf = cur_jt_qpos_buf.contiguous().view(cur_jt_qpos_buf.size(0), -1).contiguous()
            cur_jt_qtars_buf = cur_jt_qtars_buf.contiguous().view(cur_jt_qtars_buf.size(0), -1).contiguous()

            cur_out_delta_action = self.joint_idx_to_delta_action_model[joint_idx](
                {'state': cur_jt_qpos_buf, 'action': cur_jt_qtars_buf}
            )
            cur_out_delta_action = torch.clamp(cur_out_delta_action, -1.0, 1.0)

            ###### construct wm input ######
            cur_jt_compensated_target = (
                self.cur_targets[:, joint_idx].unsqueeze(1) + self.delta_action_scale * cur_out_delta_action
            )
            wm_state = self._joint_wm_state_hist(joint_idx)
            wm_action = self._joint_wm_action_hist(joint_idx, cur_jt_compensated_target)

            cur_jt_pred_output = self.joint_idx_to_wm[joint_idx]({'state': wm_state, 'action': wm_action})
            cur_jt_pred_output = scale(cur_jt_pred_output, jt_lower, jt_upper)

            actual_jt_pos = self.allegro_hand_dof_pos[..., [joint_idx]]
            loss_diff = torch.mean((cur_jt_pred_output - actual_jt_pos) ** 2)
            joint_losses.append(loss_diff)

            if verbose:
                # The un-compensated prediction is only reported, never trained
                # on, so it is evaluated on diagnostic steps only.
                with torch.no_grad():
                    wm_action_ncompensated = self._joint_wm_action_hist(
                        joint_idx, self.cur_targets[:, joint_idx].unsqueeze(1).clone()
                    )
                    cur_jt_perd_output_ncompensated = self.joint_idx_to_wm[joint_idx](
                        {'state': wm_state.clone(), 'action': wm_action_ncompensated}
                    )
                    cur_jt_perd_output_ncompensated = scale(cur_jt_perd_output_ncompensated, jt_lower, jt_upper)

                diff_abs = torch.mean(torch.abs(cur_jt_pred_output - actual_jt_pos))
                diff_abs_pre = torch.abs(self.real_wm_pred_next_state[..., [i_j]] - actual_jt_pos).mean(dim=0)
                cur_compensated_target_scale = torch.mean(torch.abs(cur_out_delta_action * self.delta_action_scale))
                diff_abs_ncompensated = torch.mean(torch.abs(cur_jt_perd_output_ncompensated - actual_jt_pos))
                print(f"diff_abs: {diff_abs.item()}, diff_abs_pre: {diff_abs_pre.item()}; cur_compensated_target_scale: {cur_compensated_target_scale.item()}; diff_abs_ncompensated: {diff_abs_ncompensated.item()}")

            joint_loss_values.append(loss_diff.detach().item())
            tot_pred_joint_state.append(cur_jt_pred_output)

        tot_loss_diff_backward = sum(joint_losses)

        if self.train_action_compensator_w_finger_rew or self.action_compensator_compute_finger_rew:
            # Compare finger positions obtained by forward kinematics on the
            # simulated joint state against those obtained with the predicted
            # joint values substituted in.
            tot_pred_joint_state = torch.cat(tot_pred_joint_state, dim=-1)
            sim_finger_trans, _, _ = self.forward_pk_chain_for_finger_pos(self.allegro_hand_dof_pos)
            pred_nex_state_w_compensator = self.allegro_hand_dof_pos.clone()
            pred_nex_state_w_compensator[..., self.sorted_figner_joint_idxes] = tot_pred_joint_state.clone()
            pred_finger_trans, _, _ = self.forward_pk_chain_for_finger_pos(pred_nex_state_w_compensator)

            diff_finger_trans = torch.norm(pred_finger_trans - sim_finger_trans, p=2, dim=-1).mean()
            if self.train_action_compensator_w_finger_rew:
                tot_loss_diff_backward = tot_loss_diff_backward + diff_finger_trans * 10.0

            if verbose:
                print(f"diff_finger_trans: {diff_finger_trans.item()}")

        # A single backward pass over the summed loss, bracketed by zero_grad/step
        # on every per-joint optimizer.
        for joint_idx in self.joint_idx_to_delta_action_model_optimizer:
            self.joint_idx_to_delta_action_model_optimizer[joint_idx].zero_grad()
        tot_loss_diff_backward.backward()
        for joint_idx in self.joint_idx_to_delta_action_model_optimizer:
            self.joint_idx_to_delta_action_model_optimizer[joint_idx].step()

        return sum(joint_loss_values) / float(len(joint_loss_values))

    def _dump_compensator_logs(self):
        """Save the accumulated per-joint diagnostics to disk and render them as a 4x4 grid of line charts."""
        out_dir = os.path.join(self.config['env']['output_name'], 'compensator')
        os.makedirs(out_dir, exist_ok=True)
        np.save(os.path.join(out_dir, 'wm_pred_outputs.npy'), self.joint_idx_to_wm_pred_delta_abs)
        np.save(os.path.join(out_dir, 'delta_action.npy'), self.joint_idx_to_delta_action)

        # Select the non-interactive backend *before* importing pyplot; older
        # Matplotlib releases ignore a backend change made after that import.
        import matplotlib
        matplotlib.use('Agg')
        import matplotlib.pyplot as plt

        fig, axes = plt.subplots(4, 4, figsize=(20, 16))
        fig.suptitle('World Model Prediction Delta Absolute Values by Joint', fontsize=16)

        for joint_idx in range(16):
            ax = axes[joint_idx // 4, joint_idx % 4]
            ax.set_title(f'Joint {joint_idx}')

            if joint_idx in self.joint_idx_to_wm_pred_delta_abs and len(self.joint_idx_to_wm_pred_delta_abs[joint_idx]) > 0:
                ax.plot(self.joint_idx_to_wm_pred_delta_abs[joint_idx], linewidth=1, alpha=0.8)
                ax.set_xlabel('Steps')
                ax.set_ylabel('Delta Abs')
                ax.grid(True, alpha=0.3)
            else:
                ax.text(0.5, 0.5, f'No data for Joint {joint_idx}',
                        ha='center', va='center', transform=ax.transAxes)

        fig.tight_layout()
        plot_file_path = os.path.join(out_dir, 'wm_pred_delta_abs_plots.png')
        fig.savefig(plot_file_path, dpi=300, bbox_inches='tight')
        # Close this specific figure so repeated dumps do not accumulate open figures.
        plt.close(fig)
    
    
    
    
    def compute_observations(self):
        
        self._refresh_gym()
        
        
        # deal with normal observation, do sliding window 
        prev_obs_buf = self.obs_buf_lag_history[:, 1:].clone()
        joint_noise_matrix = (torch.rand(self.allegro_hand_dof_pos.shape) * 2.0 - 1.0) * self.joint_noise_scale
        cur_obs_buf = unscale(
            joint_noise_matrix.to(self.device) + self.allegro_hand_dof_pos, self.allegro_hand_dof_lower_limits, self.allegro_hand_dof_upper_limits
        ).clone().unsqueeze(1) # add noise to the observations #
        cur_tar_buf = self.cur_targets[:, None]
        cur_obs_buf = torch.cat([cur_obs_buf, cur_tar_buf], dim=-1)
        self.obs_buf_lag_history[:] = torch.cat([prev_obs_buf, cur_obs_buf], dim=1)

        # refill the initialized buffers
        at_reset_env_ids = self.at_reset_buf.nonzero(as_tuple=False).squeeze(-1)
        self.obs_buf_lag_history[at_reset_env_ids, :, 0:16] = unscale(
            self.allegro_hand_dof_pos[at_reset_env_ids], self.allegro_hand_dof_lower_limits,
            self.allegro_hand_dof_upper_limits
        ).clone().unsqueeze(1)
        self.obs_buf_lag_history[at_reset_env_ids, :, 16:32] = self.allegro_hand_dof_pos[at_reset_env_ids].unsqueeze(1)
        # t_buf = (self.obs_buf_lag_history[:, -3:].reshape(self.num_envs, -1)).clone()
        t_buf = (self.obs_buf_lag_history[:, -self.lag_history_buf_length:].reshape(self.num_envs, -1)).clone()


        self.obs_buf_lag_history_qpos[:, :-1] = self.obs_buf_lag_history_qpos[:, 1: ].clone()
        self.obs_buf_lag_history_qtars[:, :-1] = self.obs_buf_lag_history_qtars[:, 1: ].clone()
        self.obs_buf_lag_history_compensated_qtars[:, :-1] = self.obs_buf_lag_history_compensated_qtars[:, 1:].clone()
        self.obs_buf_lag_history_qpos[:, -1, :] = self.allegro_hand_dof_pos.clone()
        self.obs_buf_lag_history_qtars[:, -1, :] = self.cur_targets.clone()
        self.obs_buf_lag_history_compensated_qtars[:, -1, :] = self.cur_fullhand_compensated_target.clone() 

        # 
        if self.train_action_compensator and self.train_action_compensator_uan:
            history_qtars, history_qpos = self.obs_buf_lag_history_qtars[:, -20:].clone(), self.obs_buf_lag_history_qpos[:, -20:].clone() # nn_envs x nn_hist_len x 16; nn_envs x nn_hist_len x 16
            # nn hist len x 
            hsitory_e_qpos_to_qtars = history_qtars - history_qpos # nn_envs x nn_hist_len x 16
            flatten_e = hsitory_e_qpos_to_qtars.contiguous().view(hsitory_e_qpos_to_qtars.size(0), -1).contiguous()
            self.obs_buf[:, : flatten_e.size(1)] = flatten_e # copy flatten_e to obs_buf #
            cur_obs_st = flatten_e.size(1)
        elif self.train_action_compensator and self.fingertip_only_action_compensator:
            fingertip_joints_obs = self.obs_buf_lag_history[:, -1, :16][..., [3, 7, 11, 15]].clone()
            self.obs_buf[:, :4] = fingertip_joints_obs.clone()
            cur_obs_st = 4
        elif self.train_action_compensator and self.action_compensator_not_use_history:
            self.obs_buf[:, :16] = self.obs_buf_lag_history[:, -1, :16].clone()
            cur_obs_st = 16
        else:
            self.obs_buf[:, :t_buf.shape[1]] = t_buf
            cur_obs_st = t_buf.shape[1]
        self.at_reset_buf[at_reset_env_ids] = 0
        
        
        if self.adjustable_rot_vel:
            self.obs_buf[:, t_buf.shape[1]: t_buf.shape[1] + 3] = self.envs_rot_vel.unsqueeze(1) * self.rot_axis_tsr.unsqueeze(0)
            cur_obs_st = cur_obs_st + 3
        
        if self.add_fingertip_obs:
            self.fingertip_pos = self.rigid_body_states[:, self.fingertip_handles, :3].clone() # nn_envs x 4 x 3 #
            self.fingertip_pos_flatten = self.fingertip_pos.contiguous().view(self.fingertip_pos.size(0), -1)
            # TODO: adjust the noise scale for fingertip observations
            fingertip_pos_noise_matrix = (torch.rand(self.fingertip_pos_flatten.shape).to(self.device) * 2.0 - 1.0) * self.joint_noise_scale * 0.1
            fingertip_pos_flatten_noised = self.fingertip_pos_flatten + fingertip_pos_noise_matrix
            
            self.obs_buf[:, cur_obs_st: cur_obs_st + self.fingertip_pos_flatten.size(1)] = fingertip_pos_flatten_noised.clone()
            cur_obs_st = cur_obs_st + self.fingertip_pos_flatten.size(1)

        if self.add_fingertip_ornt_obs:
            self.fingertip_ornt = self.rigid_body_states[:, self.fingertip_handles, 3: 7].clone() # nn_envs x 4 x 3 #
            self.fingertip_ornt_flatten = self.fingertip_ornt.contiguous().view(self.fingertip_ornt.size(0), -1)
            # fingertip_ornt_noise_matrix = (torch.rand(self.fingertip_ornt_flatten.shape).to(self.device) * 2.0 - 1.0) * self.joint_noise_scale * 0.1
            # fingertip_ornt_flatten_noised = self.fingertip_ornt_flatten + fingertip_ornt_noise_matrix
            self.obs_buf[:, cur_obs_st: cur_obs_st + self.fingertip_ornt_flatten.size(1)] = self.fingertip_ornt_flatten.clone()
            cur_obs_st = cur_obs_st + self.fingertip_ornt_flatten.size(1)
        
        if self.add_fingertip_state_vel_obs:
            # 96 + 16 + 13 * 4 = 164
            self.fingertip_state = self.rigid_body_states[:, self.fingertip_handles, :13].clone()
            self.fingertip_state_flatten = self.fingertip_state.contiguous().view(self.fingertip_state.size(0), -1)
            fingertip_state_noise_matrix = (torch.rand(self.fingertip_state_flatten.shape).to(self.device) * 2.0 - 1.0) * self.joint_noise_scale * 0.1
            fingertip_state_flatten_noised = self.fingertip_state_flatten + fingertip_state_noise_matrix
            self.obs_buf[:, cur_obs_st: cur_obs_st + self.fingertip_state_flatten.size(1)] = fingertip_state_flatten_noised.clone()
            cur_obs_st = cur_obs_st + self.fingertip_state_flatten.size(1)
            shadow_hand_dof_vel_noised = self.allegro_hand_dof_vel + (torch.rand(self.allegro_hand_dof_vel.shape).to(self.device) * 2.0 - 1.0) * self.joint_noise_scale * 0.1
            self.obs_buf[:, cur_obs_st: cur_obs_st + self.allegro_hand_dof_vel.size(1)] = shadow_hand_dof_vel_noised.clone()

        if self.add_object_state_obs:
            self.object_state = self.rigid_body_states[:, self.object_rb_handles, :13].clone()
            self.object_state_flatten = self.object_state.contiguous().view(self.object_state.size(0), -1)
            object_state_noise_matrix = (torch.rand(self.object_state_flatten.shape).to(self.device) * 2.0 - 1.0) * self.joint_noise_scale * 0.1
            object_state_flatten_noised = self.object_state_flatten + object_state_noise_matrix
            self.obs_buf[:, cur_obs_st: cur_obs_st + self.object_state_flatten.size(1)] = object_state_flatten_noised.clone()
            cur_obs_st = cur_obs_st + self.object_state_flatten.size(1)
            
            
        if self.hand_tracking:
            self.obs_buf[:, cur_obs_st: cur_obs_st + self.hand_tracking_targets.shape[1] * 2] = torch.cat([self.hand_tracking_targets.clone(), self.hand_tracking_targets - self.allegro_hand_dof_pos], dim=-1)
            cur_obs_st = cur_obs_st + self.hand_tracking_targets.shape[1] * 2
            
        # # quat_dist = quat_mul(self.object_rot, quat_conjugate(self.goal_rot)) #
        # self.add_obj_goal_observations = self.config['env'].get('addObjGoalObservations', False) 
        
        if self.add_obj_goal_observations:
            obj_pose_goal = torch.cat(
                [ self.guiding_pose, quat_mul(self.object_rot, quat_conjugate(self.guiding_pose)) ], dim=-1
            )
            self.obs_buf[:, cur_obs_st: cur_obs_st + obj_pose_goal.size(1)] = obj_pose_goal.clone()
            cur_obs_st = cur_obs_st + obj_pose_goal.size(1)

        if self.grasp_to_grasp:
            obj_pose_goal = torch.cat(
                [ self.goal_object_pose[..., :3] - self.object_pose[..., :3], 
                #  quat_mul(self.object_rot, quat_conjugate(self.goal_object_pose[..., 3:])) 
                 quat_mul(self.goal_object_pose[..., 3: 7], quat_conjugate(self.object_rot)) 
                 ], dim=-1
            )
            self.obs_buf[:, cur_obs_st: cur_obs_st + obj_pose_goal.size(1)] = obj_pose_goal.clone()
            cur_obs_st = cur_obs_st + obj_pose_goal.size(1)

        if self.add_force_obs:
            self.obs_buf[:, cur_obs_st: cur_obs_st + self.dof_force_tensor.shape[1]] = self.dof_force_tensor.clone()
            cur_obs_st = cur_obs_st + self.dof_force_tensor.shape[1]
            self.obs_buf[:, cur_obs_st: cur_obs_st + self.vec_sensor_tensor.shape[1]] = self.vec_sensor_tensor.clone()
            cur_obs_st = cur_obs_st + self.vec_sensor_tensor.shape[1]
        
        if self.add_contact_force_with_binary_contacts:
            norm_contacts = torch.norm(self.contact_forces, dim=-1)
            self.contact_thresh = 0.05
            contacts = torch.where(norm_contacts >= self.contact_thresh, 1.0, 0.0)
            contact_force_with_contacts = torch.cat(
                [ self.contact_forces, contacts.unsqueeze(-1) ], dim=-1
            )
            flatten_contact_force_with_contacts = contact_force_with_contacts.contiguous().view(self.contact_forces.size(0), -1)
            self.obs_buf[:, cur_obs_st: cur_obs_st + flatten_contact_force_with_contacts.size(1)] = flatten_contact_force_with_contacts.clone()
            cur_obs_st = cur_obs_st + flatten_contact_force_with_contacts.size(1) # flatten contact force # flatten contact force #
        elif self.add_contact_force_obs:
            flatten_contact_force = self.contact_forces.clone().contiguous().view(self.contact_forces.size(0), -1)
            self.obs_buf[:, cur_obs_st: cur_obs_st + flatten_contact_force.size(1)] = flatten_contact_force.clone()
            cur_obs_st = cur_obs_st + flatten_contact_force.size(1)
        
        if self.change_rot_dir or self.randomize_rot_dir:
            # print(f"adding rot axis buf to the observations")
            if self.train_goal_conditioned:
                self.obs_buf[:, cur_obs_st: cur_obs_st + 4] = self.target_obj_pose_buf.clone()
                cur_obs_st = cur_obs_st + 4
            else:
                self.obs_buf[:, cur_obs_st: cur_obs_st + 3] = self.rot_axis_buf.clone()
                cur_obs_st = cur_obs_st + 3
        
        
        if self.recovery_training:
            self.obs_buf[:, cur_obs_st: cur_obs_st + 7] = self.target_root_state_tensor[self.object_indices, :7].clone()
            cur_obs_st = cur_obs_st + 7
        
         
        if self.train_action_compensator and (not self.action_compensator_not_using_real_actions):
            cur_replay_actions = batched_index_select(self.envs_replay_qtars, indices=self.progress_buf.unsqueeze(1), dim=1).squeeze(1)
            if (not self.train_action_compensator_uan) and self.fingertip_only_action_compensator:
                self.obs_buf[:, cur_obs_st: cur_obs_st + 4] = cur_replay_actions[..., [3, 7, 11, 15]].clone()
                cur_obs_st = cur_obs_st + 4
            else:
                self.obs_buf[:, cur_obs_st: cur_obs_st + cur_replay_actions.shape[1]] = cur_replay_actions.clone()
                cur_obs_st = cur_obs_st + cur_replay_actions.shape[1]
        
        self.proprio_hist_buf[:] = self.obs_buf_lag_history[:, -self.prop_hist_len:].clone()
        self._update_priv_buf(env_id=range(self.num_envs), name='obj_position', value=self.object_pos.clone())
        
        
        if self.real_to_sim_auto_tune:
            self.tested_envs_states[:, self.testing_traj_idx, self.testing_traj_ts, :] = self.allegro_hand_dof_pos.clone()
            self.tested_envs_obj_states[:, self.testing_traj_idx, self.testing_traj_ts, :] = self.object_pose.clone()
        
        if self.add_obj_features:
            # print(f"adding obj features to the observations")
            self.obs_buf[:, cur_obs_st: cur_obs_st + self.obj_feature_dim] = self.envs_obj_features.clone()
            cur_obs_st = cur_obs_st + self.obj_feature_dim
        
        
        if self.tune_bc_model:
            
            self.original_obs_buf = self.obs_buf.clone()
            
            # compute `actor_obs` and `critic_obs` # # 
            hist_qpos = self.obs_buf_lag_history_qpos[:, -self.bc_model_history_length: ].clone()
            hist_qtars = self.obs_buf_lag_history_qtars[:, -self.bc_model_history_length: ].clone()
            
            unscaled_hist_qpos = unscale(hist_qpos, self.allegro_hand_dof_lower_limits, self.allegro_hand_dof_upper_limits)
            
            flatten_hist_qpos = hist_qpos.contiguous().view(hist_qpos.shape[0], -1).contiguous()
            flatten_hist_qtars = hist_qtars.contiguous().view(hist_qtars.shape[0], -1).contiguous()
            
            hist_info = torch.cat([flatten_hist_qpos, flatten_hist_qtars], dim=-1)
            
            
            value_net_hist_info = torch.cat(
                [ unscaled_hist_qpos, hist_qtars ], dim=-1
            )
            value_net_hist_info  = value_net_hist_info.contiguous().view(value_net_hist_info.size(0), -1).contiguous()
            
            
            self.obs_buf[:, :  hist_info.size(-1) + value_net_hist_info.size(-1)] = torch.cat(
                [ hist_info, value_net_hist_info ], dim=-1
            )
            
            
        
        
        if self.test: 
            cur_allegro_hand_dof_pos_np = self.allegro_hand_dof_pos.detach().cpu().numpy()
            cur_obj_pose_np = self.object_pose.detach().cpu().numpy()
            self.ts_to_reset_info[self.ref_ts] = {
                'shadow_hand_dof_pos': cur_allegro_hand_dof_pos_np,
                'object_pose': cur_obj_pose_np,
                'shadow_hand_dof_tars': self.cur_targets.detach().cpu().numpy(),
                'hand_pose': self.hand_pose_tsr.detach().cpu().numpy(),
            }
            self.ref_ts += 1
            if self.ref_ts >= 300: 
                to_sv_path = f'cache/{self.grasp_cache_name}_eval_res.npy'
                self._try_save(to_sv_path)
                # self.ref_ts = 0
                print(f"Saved {self.ref_ts} samples to {to_sv_path}")
        if self.evaluate:
            expanded_aranged_idxes = torch.arange(self.maxx_episode_length).unsqueeze(0).contiguous().repeat(self.num_envs, 1).contiguous().to(self.device)
            expanded_idxes_hand_qpos = expanded_aranged_idxes.unsqueeze(-1).repeat(1, 1, self.num_allegro_hand_dofs).contiguous()
            expanded_idxes_obj_pose = expanded_aranged_idxes.unsqueeze(-1).repeat(1, 1, 7).contiguous()
            expanded_idxes_extrin = expanded_aranged_idxes.unsqueeze(-1).repeat(1, 1, self.extrin_dim).contiguous()
            expanded_idxes_rot_axis = expanded_aranged_idxes.unsqueeze(-1).repeat(1, 1, 3).contiguous()
            expanded_idxes_reward = expanded_aranged_idxes.unsqueeze(-1).repeat(1, 1, 1).contiguous()
            expanded_idxes_values = expanded_aranged_idxes.unsqueeze(-1).repeat(1, 1, 1).contiguous()
            ### self.object_pose = self.root_state_tensor[self.object_indices, 0:7] #
            
            cur_hand_root_pose = self.root_state_tensor[self.hand_indices, 0:7].clone()
            cur_hand_root_ornt = cur_hand_root_pose[..., 3: ] # (nn_envs, 4) -- xyzw format data for hand root pose #
            expanded_idxes_hand_root_ornt = expanded_aranged_idxes.unsqueeze(-1).contiguous().repeat(1, 1, 4).contiguous()
            self.hand_root_ornt_buf  = torch.where(
                expanded_idxes_hand_root_ornt == self.progress_buf.unsqueeze(-1).unsqueeze(-1),
                cur_hand_root_ornt.unsqueeze(1).contiguous().repeat(1, self.maxx_episode_length, 1).contiguous(), self.hand_root_ornt_buf
            )
            
            
            # object_pose_buf #
            if self.omni_wrist_ornt:
                transformed_obj_rot = quat_mul(quat_conjugate(self.rnd_rot_tensor), self.object_rot)
                transformed_obj_pose  = torch.cat(
                    [ self.object_pos, transformed_obj_rot ], dim=-1
                )
            else:
                transformed_obj_pose = self.object_pose.clone()
            
            
            # cur_progress_buf = torch.clamp(
            #     self.progress_buf, min=0, max=self.maxxepi
            # )
            self.shadow_hand_dof_pos_buf = torch.where(
                expanded_idxes_hand_qpos == self.progress_buf.unsqueeze(-1).unsqueeze(-1),
                self.allegro_hand_dof_pos.unsqueeze(1).contiguous().repeat(1, self.maxx_episode_length, 1).contiguous(), self.shadow_hand_dof_pos_buf
            )
            # hand dof tars #
            self.shadow_hand_dof_tars_buf = torch.where(
                expanded_idxes_hand_qpos == self.progress_buf.unsqueeze(-1).unsqueeze(-1),
                self.cur_targets.unsqueeze(1).contiguous().repeat(1, self.maxx_episode_length, 1).contiguous(), self.shadow_hand_dof_tars_buf
            )
            # self.object_pose_buf = torch.where(
            #     expanded_idxes_obj_pose == self.progress_buf.unsqueeze(-1).unsqueeze(-1),
            #     self.object_pose.unsqueeze(1).contiguous().repeat(1, self.maxx_episode_length, 1).contiguous(), self.object_pose_buf
            # )
            self.object_pose_buf = torch.where(
                expanded_idxes_obj_pose == self.progress_buf.unsqueeze(-1).unsqueeze(-1),
                transformed_obj_pose.unsqueeze(1).contiguous().repeat(1, self.maxx_episode_length, 1).contiguous(), self.object_pose_buf
            )
            self.rot_axis_totep_buf = torch.where(
                expanded_idxes_rot_axis == self.progress_buf.unsqueeze(-1).unsqueeze(-1),
                self.rot_axis_buf.unsqueeze(1).contiguous().repeat(1, self.maxx_episode_length, 1).contiguous(), self.rot_axis_totep_buf
            )
            
            self.reward_buf = torch.where(
                expanded_idxes_reward == self.progress_buf.unsqueeze(-1).unsqueeze(-1),
                self.rew_buf.unsqueeze(1).unsqueeze(1).contiguous().repeat(1, self.maxx_episode_length, 1).contiguous(), self.reward_buf
            )
            
            self.ep_rotr_buf = torch.where(
                expanded_idxes_reward == self.progress_buf.unsqueeze(-1).unsqueeze(-1),
                self.rotr_buf.unsqueeze(1).unsqueeze(1).contiguous().repeat(1, self.maxx_episode_length, 1).contiguous(), self.ep_rotr_buf
            )
            self.ep_rotp_buf = torch.where(
                expanded_idxes_reward == self.progress_buf.unsqueeze(-1).unsqueeze(-1),
                self.rotp_buf.unsqueeze(1).unsqueeze(1).contiguous().repeat(1, self.maxx_episode_length, 1).contiguous(), self.ep_rotp_buf
            )
            
            
            self.value_buf = torch.where(
                expanded_idxes_values == self.progress_buf.unsqueeze(-1).unsqueeze(-1),
                self.value_vals.unsqueeze(1).contiguous().repeat(1, self.maxx_episode_length, 1).contiguous(), self.value_buf
            )
            
            try:
                self.extrin_buf = torch.where(
                    expanded_idxes_extrin == self.progress_buf.unsqueeze(-1).unsqueeze(-1),
                    self.extrin.unsqueeze(1).contiguous().repeat(1, self.maxx_episode_length, 1).contiguous(), self.extrin_buf
                )
            except:
                pass
            
            
    def _try_save(self, path):
        try:
            np.save(path, self.ts_to_reset_info)
        except Exception as e:
            tprint(f"Failed to save: {e}")
    
    def compute_reward(self, actions):
        
        if not (self.randomize_rot_dir):
            if self.custm_rot_axis:
                self.rot_axis_buf[:, :] = self.custm_rot_axis_tsr.unsqueeze(0).repeat(self.num_envs, 1).contiguous()
            else:

                if self.rot_axis == 'z':
                    self.rot_axis_buf[:, -1] = self.rot_axis_mult
                elif self.rot_axis == 'y':
                    self.rot_axis_buf[:, 1] = self.rot_axis_mult
                elif self.rot_axis == 'x':
                    self.rot_axis_buf[:, 0] = self.rot_axis_mult
                    
                if self.change_rot_dir:
                    if self.change_target_rot_dir == 'z':
                        self.rot_axis_buf[self.progress_buf >= self.change_rot_dir_period, -1] = self.change_target_rot_dir_mult
                    elif self.change_target_rot_dir == 'y':
                        self.rot_axis_buf[self.progress_buf >= self.change_rot_dir_period, 1] = self.change_target_rot_dir_mult
                    elif self.change_target_rot_dir == 'x':
                        self.rot_axis_buf[self.progress_buf >= self.change_rot_dir_period, 0] = self.change_target_rot_dir_mult
            
        
        
        # Pose penalty #
        pose_diff_penalty = ((self.allegro_hand_dof_pos - self.init_pose_buf) ** 2).sum(-1)
        # Torque penalty #
        torque_penalty = (self.torques ** 2).sum(-1)
        # Work penalty #
        work_penalty = ((self.torques * self.dof_vel_finite_diff).sum(-1)) ** 2
        # Compute offset in radians. Radians -> radians / sec
        angdiff = quat_to_axis_angle(quat_mul(self.object_rot, quat_conjugate(self.object_rot_prev)))
        object_angvel = angdiff / (self.control_freq_inv * self.dt)
        if self.omni_wrist_ornt:
            object_angvel = quat_apply(quat_conjugate(self.rnd_rot_tensor), object_angvel)
        vec_dot = (object_angvel * self.rot_axis_buf).sum(-1)
        
        if self.evaluate:
            self.rotr_buf[:] = vec_dot.clone()
            # self.rotp_buf[:] = object_angvel.clone()
            cross_dot = torch.cross(self.rot_axis_buf, object_angvel, dim=-1)
            cross_dot = torch.norm(cross_dot, p=2, dim=-1)
            self.rotp_buf[:] = cross_dot.clone()
        
        
        if self.adjustable_rot_vel:
            
            
            rotate_reward = torch.norm(self.envs_rot_vel.unsqueeze(1) * self.rot_axis_tsr.unsqueeze(0) - object_angvel, p=2, dim=-1)
            
            vec_rot_rew_clip =  torch.clip(vec_dot, max=self.angvel_clip_max, min=self.angvel_clip_min)
            
            rotate_reward =  -self.rot_vel_coef * rotate_reward + vec_rot_rew_clip
            
            # rotate_reward = vec_rot_rew_clip
            
            self.rot_axis_step += 1
            if self.rot_axis_step >= self.change_rot_axis_period:
                self.rot_axis_step = 0
                self.envs_rot_vel = torch.rand(self.num_envs).to(self.device).float() * (self.max_rot_vel - self.min_rot_vel) + self.min_rot_vel
        else:
            rotate_reward = torch.clip(vec_dot, max=self.angvel_clip_max, min=self.angvel_clip_min)
        
        
        
            
        
        # linear velocity: use position difference instead of self.object_linvel
        object_linvel = ((self.object_pos - self.object_pos_prev) / (self.control_freq_inv * self.dt)).clone()
        object_linvel_penalty = torch.norm(object_linvel, p=1, dim=-1)
        
        self.object_angvel = object_angvel.detach()
        
        # add aux pose guidance #
        if self.add_aux_pose_guidance:
            quat_diff = quat_mul(self.object_rot, quat_conjugate(self.guiding_pose))
            rot_dist = 2.0 * torch.asin(torch.clamp(torch.norm(quat_diff[:, 0:3], p=2, dim=-1), max=1.0))
            # quat_diff_scale = 0.3 # quat_diff_scale = 0.3 #
            quat_diff_scale = 0.03
            pose_guidance_rew = quat_diff_scale / (rot_dist + 1e-6)
            pose_guidance_rew = torch.clip(pose_guidance_rew, max=0.3, min=0.0)
            pose_guidance_bonus_rew = (rot_dist < self.rot_radian_threshold).float() * 0.3
            
            self.rew_buf_aux_pose_guidance[:] = pose_guidance_rew.clone()
            self.rew_buf_aux_pose_guidance_bonus[:] = pose_guidance_bonus_rew.clone()
            
            pose_guidance_rew = pose_guidance_rew + pose_guidance_bonus_rew
        
        
        if self.train_goal_conditioned:
            rotate_reward = rotate_reward * 0.0
            # pose_guidance_rew = pose_guidance_rew * 0.0
            
            
            
        if self.grasp_to_grasp:
            cur_obj_trans = self.object_pos
            cur_obj_rot = self.object_rot
            goal_obj_trans = self.goal_object_pose[..., :3] 
            goal_obj_rot = self.goal_object_pose[..., 3:7]
            goal_hand_pose = self.goal_hand_pose[..., :] #
            
            goal_dist = torch.norm(cur_obj_trans - goal_obj_trans, dim=-1, p=2, keepdim=False)
            quat_diff = quat_mul(cur_obj_rot, quat_conjugate(goal_obj_rot))
            rot_dist = 2.0 * torch.asin(torch.clamp(torch.norm(quat_diff[:, 0:3], p=2, dim=-1), max=1.0))
            
            delta_rot_quat = quat_mul(goal_obj_rot, quat_conjugate(cur_obj_rot)) # compute the delta rotation #
            delta_rot_angle = quat_to_axis_angle(delta_rot_quat)
            delta_rot_dir = delta_rot_angle / torch.clamp(torch.norm(delta_rot_angle, p=2, dim=-1, keepdim=True), min=1e-6)
            dot_anglvel_w_delta_rot_dir = (object_angvel * delta_rot_dir).sum(-1)
            dot_anglvel_w_delta_rot_dir = torch.clip(dot_anglvel_w_delta_rot_dir, max=self.angvel_clip_max, min=self.angvel_clip_min)
            
            ## NOTE: v2 reward function : difference from v1 reward function -- we add the translation distance reward ##
            ## NOTE: we add the translation distance condition ##
            rot_dist_rew = 2.0 * (0.0 - rot_dist)
            trans_dist_rew = (0.0 - goal_dist)
            # obj_dist_rew = 1 * (0.0 - 2 * goal_dist) + rot_dist_rew 
            # obj_dist_rew = rot_dist_rew 
            obj_dist_rew = (rot_dist_rew + trans_dist_rew) * 0.00001
            
            two_degree_rot_diff = 5.0 / 180.0 * 3.1415926535
            goal_dist_threshold = 0.01
            obj_bonus_flat = (rot_dist <= two_degree_rot_diff).int() 
            # + (goal_dist <= goal_dist_threshold).int()
            obj_bonus_flat = obj_bonus_flat + (goal_dist <= goal_dist_threshold).int()
            obj_bonus_flat = (obj_bonus_flat == 2).int() 
            
            bonus = torch.zeros_like(goal_dist)
            bonus = torch.where(obj_bonus_flat == 1, torch.ones_like(bonus), bonus)

            object_linvel_penalty = object_linvel_penalty * 0.0
            rotate_reward = bonus + obj_dist_rew + dot_anglvel_w_delta_rot_dir
            
            # diff hand pose as the reward #
            diff_hand_pose = torch.norm(self.allegro_hand_dof_pos - goal_hand_pose, p=2, dim=-1)
            diff_hand_pose_rew = 3 / float(16) * 0.3 * (-1) * diff_hand_pose
            
            ##### Comment out the following line to activate the object-reward-only setting #####
            rotate_reward = rotate_reward + diff_hand_pose_rew #

            # # threshold? = 0.05 #
            # # hand tracking #
            # diff_hand_pose_thres = 0.05
            # to_grasp_succ = (obj_bonus_flat == 1).int() + (diff_hand_pose_thres <= diff_hand_pose).int()
            # to_grasp_succ = to_grasp_succ == 2
        
        
        
        if self.hand_tracking:
            self.hand_tracking_period_count[:] += 1
            reset_target_env_idxes = self.hand_tracking_period_count >= self.hand_tracking_target_upd_steps
            
            reset_target_nn_envs = int(reset_target_env_idxes.float().sum().item())
            if reset_target_nn_envs > 0:
                hand_pose_rand_floats = torch_rand_float(-1.0, 1.0, (reset_target_nn_envs, self.num_allegro_hand_dofs), device=self.device)
                target_hand_pose = self.allegro_hand_dof_pos[reset_target_env_idxes, :] + hand_pose_rand_floats # * 0.25
                target_hand_pose = tensor_clamp(target_hand_pose, self.allegro_hand_dof_lower_limits, self.allegro_hand_dof_upper_limits)
                self.hand_tracking_targets[reset_target_env_idxes, :self.num_allegro_hand_dofs] = target_hand_pose
                self.hand_tracking_period_count[reset_target_env_idxes] = 0
            
            hand_tracking_diff = torch.norm(self.allegro_hand_dof_pos - self.hand_tracking_targets, p=2, dim=-1)
            hand_tracking_diff_rew = 3/float(16) * 0.3 * (-1) * hand_tracking_diff
            rotate_reward = hand_tracking_diff_rew
            object_linvel_penalty = object_linvel_penalty * 0.0
            pose_diff_penalty = pose_diff_penalty  * 0.0
            
        
        
        if self.recovery_training:
            
            ########## Difference between the current object pose and the target object pose ##########
            # rotation reward = rotate_reward * 1.0
            # object linvel penalty = object_linvel_penalty * (-0.3)
            diff_obj_pos_w_target_pos = torch.norm(self.object_pos - self.target_root_state_tensor[self.object_indices, :3], p=2, dim=-1)
            # diff_obj_rot_w_target_rot = quat_diff(self.object_rot, self.target_root_state_tensor[self.object_indices, 3:7])
            quat_diff = quat_mul(self.object_rot, quat_conjugate(self.target_root_state_tensor[self.object_indices, 3:7]))
            rot_dist = 2.0 * torch.asin(torch.clamp(torch.norm(quat_diff[:, 0:3], p=2, dim=-1), max=1.0))
            
            object_linvel_penalty = diff_obj_pos_w_target_pos  # + rot_dist * 0.03
            object_linvel_penalty = object_linvel_penalty * 0.0 # try to zero out the linear velocity penalty reward #
            ########## Difference between the current object pose and the target object pose ##########
            
            
            # ##### Not a good bonus reward here #####
            # bonus = torch.zeros_like(object_linvel_penalty)
            # obj_translation_succ_thres = self.recovery_succ_obj_pos_thres #  0.001
            # bonus = torch.where(diff_obj_pos_w_target_pos <= obj_translation_succ_thres, torch.ones_like(bonus), bonus) # 
            # ##### Not a good bonus reward here #####
            

            ############################## Bonus reward related to pos-target distance ##############################
            
            ##### Initialize the positive reward buffer #####
            bonus = torch.zeros_like(object_linvel_penalty)
            ##### Initialize the positive reward buffer #####
            
            
            # ##### Add the rotation reward ##### #
            bonus = bonus + rotate_reward.clone()
            # ##### Add the rotation reward ##### #
            
            
            # ########## Positive reward related to the pos-to-target difference ##########
            # e_mult_coef = 100 # -1.0 * e_mult_coef * diff_obj_pos_w_target_pos
            # bonus = bonus + torch.exp(-1.0 * e_mult_coef * diff_obj_pos_w_target_pos)
            # ########## Positive reward related to the pos-to-target difference ##########
            
            
            ########## Positive reward that encourages the transition from the current position to the target position ##########
            cur_to_target_lin_trans_dir = self.target_root_state_tensor[self.object_indices, :3] - self.object_pos
            cur_to_target_lin_trans_dir = cur_to_target_lin_trans_dir / torch.clamp(torch.norm(cur_to_target_lin_trans_dir, dim=-1, p=2, keepdim=True), min=1e-6)
            
            # dot_linvel_w_target_trans_dir = torch.sum(object_linvel * cur_to_target_lin_trans_dir, dim=-1) 
            
            dot_linvel_w_target_trans_dir = object_linvel[..., 2]
            bonus = bonus + dot_linvel_w_target_trans_dir
            ########## Positive reward that encourages the transition from the current position to the target position ##########
            
            
            ########## Positive reward related to the true bonus reward for the pos-to-target difference ##########
            true_bonus = torch.zeros_like(object_linvel_penalty)
            obj_translation_succ_thres = self.recovery_succ_obj_pos_thres
            obj_translation_succ_bonus = 10.0
            true_bonus = torch.where(diff_obj_pos_w_target_pos <= obj_translation_succ_thres, torch.ones_like(true_bonus) * obj_translation_succ_bonus, true_bonus)
            bonus = bonus + true_bonus
            ########## Positive reward related to the true bonus reward for the pos-to-target difference ##########
            
            
            # ########## Add the fall penalty reward ##########
            # fall_z_threshold = 0.355 #  36
            # obj_fall_cond = self.object_pos[..., 2] < fall_z_threshold
            # fall_penalty = torch.zeros_like(object_linvel_penalty)
            # fall_penalty = torch.where(obj_fall_cond, torch.ones_like(fall_penalty) * (-1000.0), fall_penalty)
            # bonus = bonus + fall_penalty
            # ########## Add the fall penalty reward ##########
            
            
            # ########## Add l2 goal reward ##########
            # diff_obj_pos_w_target_pos = torch.sum((self.object_pos - self.target_root_state_tensor[self.object_indices, :3]) ** 2, dim=-1)
            # bonus = bonus + (-1.0 * diff_obj_pos_w_target_pos)
            # ########## Add l2 goal reward ##########
            
            # selected_lin_trans_dir = cur_to_target_lin_trans_dir[100]
            # print(f"Debugging -- selected_lin_trans_dir: {selected_lin_trans_dir}")
            ############################## Bonus reward related to pos-target distance ##############################
 
            rotate_reward = bonus
        
        
        if self.add_translation:
            bonus = torch.zeros_like(object_linvel_penalty)
            
            # ##### Add the rotation reward ##### #
            bonus = bonus + rotate_reward.clone()
            # ##### Add the rotation reward ##### #
            
            if self.omni_wrist_ornt:
                object_linvel = quat_apply(quat_conjugate(self.rnd_rot_tensor), object_linvel)
            
            dot_linvel_w_target_trans_dir = torch.sum(object_linvel * self.trans_dir_buf, dim=-1)
            
            
            # dot_linvel_w_target_trans_dir = torch.clip(dot_linvel_w_target_trans_dir, max=self.angvel_clip_max / 3, min=self.angvel_clip_min)
            dot_linvel_w_target_trans_dir = torch.clip(dot_linvel_w_target_trans_dir, max=self.angvel_clip_max / 10, min=self.angvel_clip_min)
            
            # dot_linvel_w_target_trans_dir = dot_linvel_w_target_trans_dir * 5.0
            dot_linvel_w_target_trans_dir = dot_linvel_w_target_trans_dir * 10 #  3
            
            bonus = bonus + dot_linvel_w_target_trans_dir
            
            if self.progress_buf[0].item() % 8 == 0:
                envs_mean_rotation_rew = rotate_reward.mean(dim=0)
                envs_mean_trans_rew = dot_linvel_w_target_trans_dir.mean(dim=0)
                print(f"envs_mean_rotation_rew: {envs_mean_rotation_rew.detach().mean().item()}, envs_mean_trans_rew: {envs_mean_trans_rew.detach().mean().item()}")
            
            object_linvel_penalty = object_linvel_penalty * 0.0
            rotate_reward = bonus
            
        
        if self.train_action_compensator:
            if self.fingertip_only_action_compensator:
                cur_real_states = batched_index_select(self.envs_replay_qpos, indices=self.progress_buf.unsqueeze(1), dim=1).squeeze(1)[..., [3, 7, 11, 15]]
                diff_states_w_real_states = torch.sum((cur_real_states - self.allegro_hand_dof_pos[..., [3, 7, 11, 15]]) ** 2, dim=-1)
                # pose_diff_penalty = diff_states_w_real_states # * 0.1
            else: 
                cur_real_states = batched_index_select(self.envs_replay_qpos, indices=self.progress_buf.unsqueeze(1), dim=1).squeeze(1)
                diff_states_w_real_states = torch.sum((cur_real_states - self.allegro_hand_dof_pos) ** 2, dim=-1)
                # pose_diff_penalty = diff_states_w_real_states # * 0.1
            
            if self.train_action_compensator_w_obj_motion_pred:
                sim_input, real_input = self._get_obj_motion_pred_input()
                diff_states_w_real_states = diff_states_w_real_states + torch.sum((real_input - sim_input) ** 2, dim=-1).detach()
                # pose_diff_penalty = pose_diff_penalty +  diff_states_w_real_states.detach()
            pose_diff_penalty = diff_states_w_real_states
            
            object_linvel_penalty = object_linvel_penalty * 0.0
            if not self.action_compensator_w_obj:
                rotate_reward = rotate_reward * 0.0
            else:
                # compensator_w_obj_rew_type, 'angvel' 'angvelthres' 'notfalling'
                if self.compensator_w_obj_rew_type == 'angvel':
                    rotate_reward = rotate_reward
                elif self.compensator_w_obj_rew_type == 'angvelthres':
                    rotate_reward = torch.where(rotate_reward >= 0.1, torch.ones_like(rotate_reward) * 0.1, rotate_reward)
                elif self.compensator_w_obj_rew_type == 'notfalling':
                    rotate_reward = torch.ones_like(rotate_reward) * 0.1
                else:
                    raise ValueError(f"Unknown compensator_w_obj_rew_type: {self.compensator_w_obj_rew_type}")
        
        
        
        if self.train_action_compensator_w_real_wm:
            diff_real_trans_nex_state_w_cur_state = torch.norm(self.real_wm_pred_next_state - self.allegro_hand_dof_pos[..., self.sorted_figner_joint_idxes], p=2, dim=-1)
            if self.progress_buf[0].item() % 8 == 0:
                diff_wo_compensator = torch.norm(self.real_wm_pred_next_state_orijoints - self.allegro_hand_dof_pos[..., self.sorted_figner_joint_idxes], p=2, dim=-1)
                # print(f"abs_mean_diff: {torch.abs(self.real_wm_pred_next_state - self.allegro_hand_dof_pos[..., self.sorted_figner_joint_idxes]).mean(dim=0)}, obs_mean_diff_wo_compensator: {torch.abs(diff_wo_compensator).mean(dim=0)}")
            # diff_real_trans_nex_state_w_cur_state = torch.sum((self.real_wm_pred_next_state - self.allegro_hand_dof_pos[..., self.sorted_figner_joint_idxes]) ** 2, dim=-1)
            # torch.norm(self.real_wm_pred_next_state - self.allegro_hand_dof_pos, p=2, dim=-1)
            
            # compensator compute finger rew #
            if self.train_action_compensator_w_finger_rew or self.action_compensator_compute_finger_rew:
                sim_finger_trans, sim_finger_rot_quat , sim_finger_trans_per_matrix = self.forward_pk_chain_for_finger_pos(self.allegro_hand_dof_pos)
                
                pred_nex_state_tot = self.allegro_hand_dof_pos.clone()
                pred_nex_state_tot[..., self.sorted_figner_joint_idxes] = self.real_wm_pred_next_state.clone()
                pred_finger_trans, pred_finger_rot_quat , pred_finger_trans_per_matrix = self.forward_pk_chain_for_finger_pos(pred_nex_state_tot)
                diff_finger_trans = torch.norm(sim_finger_trans - pred_finger_trans, p=2, dim=-1) # .mean()
                
                pred_nex_state_wo_compensator_tot = self.allegro_hand_dof_pos.clone()
                pred_nex_state_wo_compensator_tot[..., self.sorted_figner_joint_idxes] = self.real_wm_pred_next_state_orijoints.clone()
                pred_finger_trans_wo_compensator, pred_finger_rot_quat_wo_compensator, pred_finger_trans_per_matrix_wo_compensator = self.forward_pk_chain_for_finger_pos(pred_nex_state_wo_compensator_tot) # sim finger trans - pred finger trans #
                diff_finger_trans_wo_compensator = torch.norm(sim_finger_trans - pred_finger_trans_wo_compensator, p=2, dim=-1) # .mean()
                
                # diff_finger_trans = diff_finger_trans - diff_finger_trans_wo_compensator # the smaller, the better #
                
                # diff_finger_rot_quat = quat_mul(pred_finger_rot_quat, quat_conjugate(sim_finger_rot_quat))
                # dist_finger_rot_quat = 2.0 * torch.asin(torch.clamp(torch.norm(diff_finger_rot_quat[:, 0:3], p=2, dim=-1), max=1.0))
                #### modify fingertip based reward ####
                # diff_real_trans_nex_state_w_cur_state = diff_real_trans_nex_state_w_cur_state + diff_finger_trans # + dist_finger_rot_quat.mean() * 0.03
                
                # sim_finger_trans # nn_envs x nn_fingers x 3
                # pred_finger_trans # nn_envs x nn_fingers x 3
                cross_finger_trans_sim = sim_finger_trans_per_matrix.unsqueeze(2) - sim_finger_trans_per_matrix.unsqueeze(1)
                cross_finger_trans_pred = pred_finger_trans_per_matrix.unsqueeze(2) - pred_finger_trans_per_matrix.unsqueeze(1)
                # nn_envs x nn_fingers x nn_fingers x 3
                diff_cross_finger_trans_sim_pred = torch.norm(cross_finger_trans_sim - cross_finger_trans_pred, p=2, dim=-1)
                diff_cross_finger_trans_sim_pred = diff_cross_finger_trans_sim_pred.sum(dim=-1).sum(dim=-1)
                diff_cross_finger_trans_sim_pred = diff_cross_finger_trans_sim_pred / float(sim_finger_trans_per_matrix.size(1) ** 2 - sim_finger_trans_per_matrix.size(1)) 
                
                if self.train_action_compensator_w_finger_rew:
                    # diff_real_trans_nex_state_w_cur_state = diff_real_trans_nex_state_w_cur_state * 0.0 + diff_finger_trans
                    diff_real_trans_nex_state_w_cur_state = diff_real_trans_nex_state_w_cur_state + 0.5 * diff_finger_trans
                    
                    # print(f"diff_cross_finger_trans_sim_pred: {diff_cross_finger_trans_sim_pred.size()}")
                    diff_real_trans_nex_state_w_cur_state = diff_real_trans_nex_state_w_cur_state + 0.5 *  diff_cross_finger_trans_sim_pred
                    
            
            if self.action_compensator_add_invaction_rew:
                last_hand_qpos = self.obs_buf_lag_history_qpos[:, -1].clone() # (nn_envs, 16)
                unscaled_qpos = unscale(last_hand_qpos, self.allegro_hand_dof_lower_limits, self.allegro_hand_dof_upper_limits)
                unscaled_cur_qpos = unscale(self.allegro_hand_dof_pos, self.allegro_hand_dof_lower_limits, self.allegro_hand_dof_upper_limits)
                
                input_dict = {
                    'state': unscaled_qpos,
                    'nex_state': unscaled_cur_qpos
                }
                
                with torch.no_grad():
                    pred_action = self.inverse_dynamics_model(input_dict)
                    unscale_compensated_action = unscale(self.cur_compensated_targets.detach(), self.allegro_hand_dof_lower_limits, self.allegro_hand_dof_upper_limits)
                    
                    diff_action_w_compensated_action = torch.norm(pred_action - unscale_compensated_action, p=2, dim=-1)
                    
                diff_real_trans_nex_state_w_cur_state = diff_real_trans_nex_state_w_cur_state + diff_action_w_compensated_action.detach() * 0.001
            
            
            self.rew_buf[:] = -1.0 * diff_real_trans_nex_state_w_cur_state
        else:
            self.rew_buf[:] = compute_hand_reward(
                object_linvel_penalty, self.object_linvel_penalty_scale,
                rotate_reward, self.rotate_reward_scale,
                pose_diff_penalty, self.pose_diff_penalty_scale,
                torque_penalty, self.torque_penalty_scale,
                work_penalty, self.work_penalty_scale,
            )
        
        if self.add_zrot_penalty:
            # use the object rotation to translate the object points --- after that, calculate the penalty value #
            expanded_obj_rot = self.object_rot.contiguous().unsqueeze(1).repeat(1, self.envs_obj_points.size(1), 1).contiguous()
            expanded_obj_rot = expanded_obj_rot.contiguous().view(-1, 4).contiguous()
            
            expanded_obj_poins = self.envs_obj_points.contiguous().view(-1, 3).contiguous()
            expanded_obj_points = quat_apply(expanded_obj_rot, expanded_obj_poins) # object rot --- nn_envs x nn_obj_points x 3
            
            expanded_obj_points = expanded_obj_points.contiguous().view(self.num_envs, self.envs_obj_points.size(1), 3).contiguous() # nn_envs x nn_pts x 3
            
            expanded_obj_points = expanded_obj_points + self.object_pos.unsqueeze(1) # object pos --- nn_envs x nn_obj_points x 3
            # expanded obj points #
            maxx_obj_pts_z = torch.max(expanded_obj_points[:, :, 2], dim=1)[0]
            min_obj_pts_z = torch.min(expanded_obj_points[:, :, 2], dim=1)[0]
            diff_max_min_obj_pts_z = torch.abs(maxx_obj_pts_z - min_obj_pts_z)
            z_penalty = -1.0 * diff_max_min_obj_pts_z
            self.rew_buf[:] = self.rew_buf[:] + z_penalty 
        
        
        
        if self.add_rotp:
            rotp = torch.norm(torch.cross(object_angvel, self.rot_axis_buf), p=2, dim=-1)
            self.rew_buf[:] = self.rew_buf[:] - self.cur_rotp_coef * rotp
        
        
        self.rew_buf_wo_aux[:] = self.rew_buf.clone()
        
        if self.recovery_training or self.add_translation:
            self.rew_buf_wo_aux[:] = dot_linvel_w_target_trans_dir.clone()
        
        if self.grasp_to_grasp:
            self.rew_buf_wo_aux[:] = bonus.float().clone()
            
        if self.hand_tracking:
            self.rew_buf_wo_aux[:] = rotate_reward * self.rotate_reward_scale
        
        if self.train_action_compensator and self.action_compensator_w_obj:
            self.rew_buf_wo_aux[:] = diff_states_w_real_states.clone() # set the aux to hand pose diff #
        
        if self.train_action_compensator_w_real_wm:
            self.rew_buf_wo_aux[:] = rotate_reward * self.rotate_reward_scale
        
        
        if self.add_aux_pose_guidance:
            self.rew_buf += pose_guidance_rew * self.aux_pose_guidance_coef
        
        
        if self.train_goal_conditioned:
            cur_obj_ornt = self.object_pose[..., 3: 7]
            # \delta rot cur_rot = target_rot
            delta_obj_ornt = quat_mul(self.target_obj_pose_buf, quat_conjugate(cur_obj_ornt))
            
            ### add the goal rotation reward ###
            diff_rad = 2.0 * torch.asin(torch.clamp(torch.norm(delta_obj_ornt[..., 0:3], p=2, dim=-1), max=1.0)) 
            diff_rad_rew = 0.2 * (3.14 - diff_rad) 
            diff_rad_bonus = torch.zeros_like(diff_rad_rew)
            diff_rad_bonus = torch.where(diff_rad < 0.1 * 3.1415926, torch.ones_like(diff_rad_bonus), diff_rad_bonus)
            self.rew_buf[:] += diff_rad_rew + diff_rad_bonus
            
            
            dx, dy, dz = get_euler_xyz(delta_obj_ornt)
            delta_rot_euler = torch.stack([dx, dy, dz], dim=-1) # nn_envs x 3 #
            delta_rot_euler = delta_rot_euler / torch.clamp(torch.norm(delta_rot_euler, p=2, dim=-1, keepdim=True), min=1e-6)
            self.rot_axis_buf[:, :] = delta_rot_euler.clone()
            
            
        if self.action_compensator_add_invaction_rew:
            self.rew_buf_aux_pose_guidance[:] = diff_action_w_compensated_action # .detach() * 0.01
            
        
        if self.action_compensator_compute_finger_rew:
            # print(f"diff_finger_trans: {diff_finger_trans.mean()}")
            self.rew_buf_aux_pose_guidance_bonus[:] = (diff_finger_trans - diff_finger_trans_wo_compensator).clone()
            self.rew_buf_aux_pose_guidance[:] = diff_cross_finger_trans_sim_pred.clone()
        
        
        self.reset_buf[:] = self.check_termination(self.object_pos)
        
        if self.hand_tracking and (not self.hand_tracking_nobj):
            obj_falling_mask = self.check_obj_falling_reset(self.object_pos)
            self.rew_buf[obj_falling_mask] = -200.0 # give it a large negative reward #
            
        if self.hand_tracking:
            if self.reset_buf.float().sum() >= self.num_envs // 2:
                cur_reset_rewards = self.rew_buf[self.reset_buf ]
                avg_reset_rewards = torch.mean(cur_reset_rewards).item()
                print(f"Average reset rewards: {avg_reset_rewards}")
        
        if self.add_aux_pose_guidance:
            
            ########## v2 guiding pose ##########
            if self.omni_wrist_ornt:
                rotated_rot_axis_buf = quat_apply(self.rnd_rot_tensor, self.rot_axis_buf)
                cur_obj_guiding_rot = quat_from_angle_axis(  self.guiding_delta_rot_radian * torch.ones((self.num_envs,), dtype=torch.float32).to(self.device) * (self.upd_guiding_pose_steps + 1) , rotated_rot_axis_buf) # guiding delta rot radian
            else:
                cur_obj_guiding_rot = quat_from_angle_axis(  self.guiding_delta_rot_radian * torch.ones((self.num_envs,), dtype=torch.float32).to(self.device) * (self.upd_guiding_pose_steps + 1) , self.rot_axis_buf) # guiding delta rot radian 
            nex_guiding_pose = quat_mul(cur_obj_guiding_rot, self.object_init_state[..., 3:7]) 
            self.upd_guiding_pose_steps[rot_dist < self.rot_radian_threshold] += 1
            ########## v2 guiding pose ##########
            
            self.guiding_pose[rot_dist < self.rot_radian_threshold] = nex_guiding_pose[rot_dist < self.rot_radian_threshold]
        
        
        self.extras['rotation_reward'] = rotate_reward.mean()
        self.extras['object_linvel_penalty'] = object_linvel_penalty.mean()
        self.extras['pose_diff_penalty'] = pose_diff_penalty.mean()
        self.extras['work_done'] = work_penalty.mean()
        self.extras['torques'] = torque_penalty.mean()
        self.extras['roll'] = object_angvel[:, 0].mean()
        self.extras['pitch'] = object_angvel[:, 1].mean()
        self.extras['yaw'] = object_angvel[:, 2].mean()

        if self.evaluate:
            finished_episode_mask = self.reset_buf == 1
            self.stat_sum_rewards += self.rew_buf.sum()
            self.stat_sum_rotate_rewards += rotate_reward.sum()
            self.stat_sum_torques += self.torques.abs().sum()
            self.stat_sum_obj_linvel += (self.object_linvel ** 2).sum(-1).sum()
            self.stat_sum_episode_length += (self.reset_buf == 0).sum()
            self.env_evaluated += (self.reset_buf == 1).sum()
            self.env_timeout_counter[finished_episode_mask] += 1
            info = f'progress {self.env_evaluated} / {self.max_evaluate_envs} | ' \
                   f'reward: {self.stat_sum_rewards / self.env_evaluated:.2f} | ' \
                   f'eps length: {self.stat_sum_episode_length / self.env_evaluated:.2f} | ' \
                   f'rotate reward: {self.stat_sum_rotate_rewards / self.env_evaluated:.2f} | ' \
                   f'lin vel (x100): {self.stat_sum_obj_linvel * 100 / self.stat_sum_episode_length:.4f} | ' \
                   f'command torque: {self.stat_sum_torques / self.stat_sum_episode_length:.2f}'
            tprint(info)
            
            sv_exp_mask = ((self.reset_buf == 1).float() + (self.progress_buf == self.maxx_episode_length).float()) > 1.5
            if sv_exp_mask.sum() > 0 and not self.evaluate_for_statistics:
                reset_env_progress_buf = self.progress_buf[sv_exp_mask]
                reset_env_shadow_hand_dof_pos = self.shadow_hand_dof_pos_buf[sv_exp_mask]
                reset_env_shadow_hand_dof_tars = self.shadow_hand_dof_tars_buf[sv_exp_mask]
                reset_env_object_pose = self.object_pose_buf[sv_exp_mask]
                reset_env_extrin = self.extrin_buf[sv_exp_mask]
                reset_env_rot_axis = self.rot_axis_totep_buf[sv_exp_mask]
                reset_reward_buf = self.reward_buf[sv_exp_mask]
                reset_value_buf = self.value_buf[sv_exp_mask]
                reset_hand_root_ornt = self.hand_root_ornt_buf[sv_exp_mask]
                sv_dict = {
                    'progress_buf': reset_env_progress_buf.detach().cpu().numpy(),
                    'shadow_hand_dof_pos': reset_env_shadow_hand_dof_pos.detach().cpu().numpy(),
                    'shadow_hand_dof_tars': reset_env_shadow_hand_dof_tars.detach().cpu().numpy(),
                    'object_pose': reset_env_object_pose.detach().cpu().numpy(),
                    'extrin': reset_env_extrin.detach().cpu().numpy(),
                    'rot_axis': reset_env_rot_axis.detach().cpu().numpy(),
                    'reward_buf': reset_reward_buf.detach().cpu().numpy(),
                    'value_buf': reset_value_buf.detach().cpu().numpy(),
                    'hand_root_ornt': reset_hand_root_ornt.detach().cpu().numpy(),
                }

                sv_dict_root = "."
                sv_dict_folder = f"cache/eval_res_{self.hand_type}_{self.object_type}_{self.rot_axis}_{self.rot_axis_mult}_facing_{self.hand_facing_dir}_{self.grasp_cache_name}_m{self.upper_mass_limit}"
                
                if len(self.additional_tag) > 0:
                    sv_dict_folder += f"_{self.additional_tag}"
                
                sv_dict_folder = os.path.join(sv_dict_root, sv_dict_folder) 
                os.makedirs(sv_dict_folder, exist_ok=True)
                sv_dict_fn = os.path.join(sv_dict_folder, f"{self.grasp_cache_name}_eval_res_{self.sv_cache_nn}.npy")
                np.save(sv_dict_fn, sv_dict)
                self.sv_cache_nn += 1
                tprint(f"Saved {sv_exp_mask.sum()} samples to {sv_dict_fn}")
                
                if self.openloop_replay:
                    exit()
            
            if self.evaluate_for_statistics:
                
                if self.evaluate_goal_conditioned:
                    finished_episode_for_eval_exp_mask = ((self.reset_buf == 1).float() + (self.progress_buf == self.maxx_episode_length).float()) > 1.5
                    # finished_episode_for_eval_exp_mask = ((self.reset_buf == 1).float() + (self.progress_buf >= 100).float()) > 1.5
                else:
                    # finished_episode_for_eval_exp_mask = ((self.reset_buf == 1).float() + (self.progress_buf >= 200).float()) > 1.5
                    finished_episode_for_eval_exp_mask = ((self.reset_buf == 1).float() + (self.progress_buf >= 100).float()) > 1.5
                
                reset_envs_progress_buf = self.progress_buf[finished_episode_for_eval_exp_mask].detach() # .cpu().numpy().tolist()
                reset_envs_ep_rotr = self.ep_rotr_buf[finished_episode_for_eval_exp_mask].sum(dim=-1).sum(dim=-1).detach() #.cpu().numpy().tolist()
                reset_envs_ep_rotp = self.ep_rotp_buf[finished_episode_for_eval_exp_mask].sum(dim=-1).sum(dim=-1).detach() #.cpu().numpy().tolist()
                reset_envs_avgep_rotp = reset_envs_ep_rotp / reset_envs_progress_buf.float()
                reset_envs_ep_rew = self.reward_buf[finished_episode_for_eval_exp_mask].sum(dim=-1).sum(dim=-1).detach()
                
                self.evaluated_progress_length.extend(reset_envs_progress_buf.cpu().tolist())
                self.evaluated_ep_rotr.extend(reset_envs_ep_rotr.cpu().tolist())
                self.evaluated_ep_rotp.extend(reset_envs_avgep_rotp.cpu().tolist())
                self.evaluated_ep_rew.extend(reset_envs_ep_rew.cpu().tolist())
                
                self.ep_rotr_buf[finished_episode_for_eval_exp_mask] = 0.0
                self.ep_rotp_buf[finished_episode_for_eval_exp_mask] = 0.0
                self.reward_buf[finished_episode_for_eval_exp_mask] = 0.0
                
                if self.evaluate_goal_conditioned and torch.sum(finished_episode_for_eval_exp_mask.float()).item() >= 1:
                    # evaluatte goal conditioned #
                    # calculate the distance between all object rot quat in the full episode and the target object rot quat #
                    obj_pose_buf_finished_episode_for_eval = self.object_pose_buf[finished_episode_for_eval_exp_mask].detach()
                    obj_ornt_buf_finished_episode_for_eval = obj_pose_buf_finished_episode_for_eval[..., 3: 7] # (nn_finished_envs, 4)
                    # pos buf finisehd episode for eval #
                    target_obj_pose = self.target_obj_pose_buf[finished_episode_for_eval_exp_mask].detach() # (nn_finished_envs, 4)
                    # calculate the diff between target obj pose and the obj pose in the episode #
                    target_obj_pose = target_obj_pose.contiguous().unsqueeze(1).repeat(1, obj_ornt_buf_finished_episode_for_eval.size(1), 1).contiguous()
                    
                    # Extract target object orientation (quaternion)
                    # target_obj_ornt = target_obj_pose[:, 3:7]  # (nn_finished_envs, 4)
                    
                    # Calculate orientation difference in radians
                    # Convert quaternions to axis-angle representation and get the angle
                    def quaternion_angle_diff(q1, q2):
                        # Normalize quaternions
                        # q1_norm = q1 / torch.norm(q1, dim=-1, keepdim=True)
                        # q2_norm = q2 / torch.norm(q2, dim=-1, keepdim=True)
                        q1_norm = q1[..., [3, 0, 1, 2]]; q2_norm = q2[..., [3, 0, 1, 2]]; 
                        
                        # Calculate the relative quaternion (q1 * q2_conjugate)
                        q1_conj = torch.cat([q1_norm[..., :1], -q1_norm[..., 1:]], dim=-1)
                        q_rel = quaternion_multiply(q1_norm, q1_conj)
                        
                        # Extract the angle from the quaternion (2 * acos(w))
                        # The angle is 2 * acos(|w|) where w is the real part
                        w = torch.abs(q_rel[..., 0])
                        angle = 2.0 * torch.acos(torch.clamp(w, 0.0, 1.0))
                        
                        return angle
                    
                    def quaternion_multiply(q1, q2):
                        # q1 * q2
                        w1, x1, y1, z1 = q1[..., 0], q1[..., 1], q1[..., 2], q1[..., 3]
                        w2, x2, y2, z2 = q2[..., 0], q2[..., 1], q2[..., 2], q2[..., 3]
                        
                        w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2
                        x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2
                        y = w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2
                        z = w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2
                        
                        return torch.stack([w, x, y, z], dim=-1)
                    
                    
                    
                    
                    if self.train_goal_conditioned:
                        quat_diff = quat_mul(self.object_pose[finished_episode_for_eval_exp_mask, 3: 7], quat_conjugate(target_obj_pose[:, 0, :]))
                        orientation_diff_rad = 2.0 * torch.asin(torch.clamp(torch.norm(quat_diff[..., 0:3], p=2, dim=-1), max=1.0)) 
                        succ = (orientation_diff_rad <= 0.1 * 3.1415926).float()
                    else:
                        # quat diff
                        quat_diff = quat_mul(obj_ornt_buf_finished_episode_for_eval, quat_conjugate(target_obj_pose))
                        orientation_diff_rad = 2.0 * torch.asin(torch.clamp(torch.norm(quat_diff[..., 0:3], p=2, dim=-1), max=1.0)) 
                        # print(f"orientation_diff_rad: {orientation_diff_rad}")
                        # Calculate orientation difference in radians
                        # orientation_diff_rad = quaternion_angle_diff(obj_ornt_buf_finished_episode_for_eval, target_obj_pose)
                        succ = (torch.sum((orientation_diff_rad <= 0.1 * 3.1415926).float(), dim=-1) >= 1).float()
                    
                    minn_orientation_diff_rad, _ = torch.min(orientation_diff_rad, dim=-1) # 
                    avg_minn_orientation_diff_rad = torch.mean(minn_orientation_diff_rad).item()
                    # print(f"orientation diff rad: {avg_minn_orientation_diff_rad}")
                    
                    # Store the orientation difference for evaluation
                    self.evaluated_orientation_diff.extend(succ.cpu().tolist())
            

            if self.env_evaluated >= self.max_evaluate_envs:
                
                if self.evaluate_for_statistics:
                    self.evaluated_progress_length = np.array(self.evaluated_progress_length)
                    self.evaluated_ep_rotr = np.array(self.evaluated_ep_rotr)
                    self.evaluated_ep_rotp = np.array(self.evaluated_ep_rotp)
                    self.evaluated_orientation_diff = np.array(self.evaluated_orientation_diff)
                    self.evaluated_ep_rew = np.array(self.evaluated_ep_rew)
                    
                    self.evaluated_progress_length_time = 20.0 * (self.evaluated_progress_length / 400.0)
                    avg_eptime, std_eptime = np.mean(self.evaluated_progress_length_time), np.std(self.evaluated_progress_length_time)
                    avg_ep_rotr, std_ep_rotr = np.mean(self.evaluated_ep_rotr), np.std(self.evaluated_ep_rotr)
                    avg_ep_rotp, std_ep_rotp = np.mean(self.evaluated_ep_rotp), np.std(self.evaluated_ep_rotp)
                    avg_ep_rew, std_ep_rew = np.mean(self.evaluated_ep_rew), np.std(self.evaluated_ep_rew)
                    
                    # avg orientation diff #
                    avg_orientation_diff, std_orientation_diff = np.mean(self.evaluated_orientation_diff), np.std(self.evaluated_orientation_diff)
                    print(f"avg_eptime: {avg_eptime}, std_eptime: {std_eptime}, avg_ep_rotr: {avg_ep_rotr}, std_ep_rotr: {std_ep_rotr}, avg_ep_rotp: {avg_ep_rotp}, std_ep_rotp: {std_ep_rotp}, avg_orientation_diff: {avg_orientation_diff:.4f}, std_orientation_diff: {std_orientation_diff:.4f}, avg_ep_rew: {avg_ep_rew:.4f}, std_ep_rew: {std_ep_rew:.4f}")
                    
                    # np.save(os.path.join(sv_dict_root, f"eval_res_{self.hand_type}_{self.object_type}_{self.rot_axis}_{self.rot_axis_mult}_facing_{self.hand_facing_dir}_{self.grasp_cache_name}_m{self.upper_mass_limit}_per_env.npy"), {
                    #     'progress_length': self.evaluated_progress_length,
                    #     'ep_rotr': self.evaluated_ep_rotr,
                    #     'ep_rotp': self.evaluated_ep_rotp,
                    # })
                exit()
    

    def post_physics_step(self):
        """Run the per-step bookkeeping that follows a simulation step.

        Increments ``self.progress_buf``, zeroes ``self.reset_buf``, refreshes
        the simulator state, optionally trains the delta-action model, then
        computes rewards, resets the envs flagged in ``self.reset_buf`` and
        recomputes observations. When a viewer is attached and
        ``self.debug_viz`` is set, the object's local axes are drawn per env.
        """
        self.progress_buf += 1
        self.reset_buf[:] = 0
        self._refresh_gym()

        if self.train_action_compensator_w_real_wm:  # train compensator with the real world model
            train_fn = (
                self.train_multi_delta_action_model
                if self.train_action_compensator_w_real_wm_multi_compensator
                else self.train_delta_action_model
            )
            # Gradients are switched on explicitly for the training call.
            with torch.enable_grad():
                wm_training_loss = train_fn()
            # Only log on every 8th step of env 0 to keep the console readable.
            if self.progress_buf[0].item() % 8 == 0:
                print(f"delta action model training loss: {wm_training_loss}")

        self.compute_reward(self.actions)
        done_env_ids = self.reset_buf.nonzero(as_tuple=False).squeeze(-1)
        if len(done_env_ids) > 0:
            self.reset_idx(done_env_ids)
        self.compute_observations()

        if not (self.viewer and self.debug_viz):
            return

        # Debug visualisation: draw the object's x / y / z axes (red / green / blue).
        self.gym.clear_lines(self.viewer)
        self.gym.refresh_rigid_body_state_tensor(self.sim)

        axis_styles = (
            ([1, 0, 0], [0.85, 0.1, 0.1]),
            ([0, 1, 0], [0.1, 0.85, 0.1]),
            ([0, 0, 1], [0.1, 0.1, 0.85]),
        )
        for env_idx in range(self.num_envs):
            origin = self.object_pos[env_idx].cpu().numpy()
            for unit_axis, rgb in axis_styles:
                # Axis tip = object position + rotated unit axis scaled to 0.2.
                offset = quat_apply(self.object_rot[env_idx], to_torch(unit_axis, device=self.device) * 0.2)
                tip = (self.object_pos[env_idx] + offset).cpu().numpy()
                segment = [origin[0], origin[1], origin[2], tip[0], tip[1], tip[2]]
                self.gym.add_lines(self.viewer, self.envs[env_idx], 1, segment, list(rgb))

    def _create_ground_plane(self):
        """Add a ground plane whose normal is +z to the simulation."""
        ground = gymapi.PlaneParams()
        ground.normal = gymapi.Vec3(0.0, 0.0, 1.0)  # plane faces up along z
        self.gym.add_ground(self.sim, ground)

    def pre_physics_step(self, actions):
        """Turn policy actions into joint position targets and apply object forces.

        The hand target is assembled in stages; later stages override earlier
        ones, so the order of the blocks below matters:

        1. base target: the (optionally noised / rescaled) action, used directly
           when ``tune_bc_model`` is set, otherwise integrated on top of
           ``self.prev_targets`` with a 1/24 step;
        2. replayed target plus a scaled residual when training the action
           compensator on recorded actions;
        3. optional additive ``self.delta_actions``;
        4. open-loop replay targets or BC-policy targets, which replace the
           result of the stages above.

        The result is clamped to the joint limits and written to
        ``self.cur_targets``. ``real_to_sim_auto_tune`` afterwards overwrites
        ``self.cur_targets`` with a preset trajectory (without clamping).

        Besides ``self.cur_targets`` this updates ``self.actions``,
        ``self.prev_targets``, ``self.object_rot_prev``, ``self.object_pos_prev``
        and, depending on the enabled flags, ``self.compensated_targets``,
        ``self.cur_compensated_targets``, ``self.testing_traj_ts`` and
        ``self.rb_forces``.

        Args:
            actions: policy output, one row per env (assumed shape
                ``(num_envs, action_dim)`` -- TODO confirm against caller).
                The first ``self.num_allegro_hand_dofs`` columns drive the
                hand; with ``apply_obj_virtual_force`` the remaining columns
                are written into the object's slot of ``self.rb_forces``.
                With fingertip-only compensator training the input is
                scattered onto joints 3 / 7 / 11 / 15 of a 16-joint action.
        """
        if self.train_action_compensator and self.fingertip_only_action_compensator:
            # Scatter the per-fingertip outputs into a zero-filled 16-joint action.
            actions_full = torch.zeros((actions.size(0), 16), dtype=torch.float32).to(actions.device)
            actions_full[..., [3, 7, 11, 15]] = actions.clone()
            actions = actions_full.clone()

        self.actions = actions[..., : self.num_allegro_hand_dofs].clone().to(self.device)

        if self.evaluate and self.evaluate_action_add_noise:
            self.actions = self.actions + torch.randn_like(self.actions) * self.evaluate_action_noise_std

        if self.tune_bc_model:
            # The action is interpreted as a normalised joint position and
            # mapped into the joint limits.
            self.actions = scale(self.actions, self.allegro_hand_dof_lower_limits, self.allegro_hand_dof_upper_limits)

            if self.tune_bc_via_compensator_model:
                # Residual on top of the BC model output. Use the hand slice of
                # the raw action on the sim device, matching every other use of
                # ``actions`` in this method (the unsliced tensor breaks as soon
                # as the action carries extra columns or lives on another device).
                hand_actions = actions[..., : self.num_allegro_hand_dofs].to(self.device)
                self.actions = self.bc_model_actions + hand_actions * self.delta_action_scale

            targets = self.actions
        else:
            targets = self.prev_targets + 1 / 24 * self.actions

        if self.train_action_compensator and (not self.action_compensator_not_using_real_actions):
            # Recorded target at the current step of each env (nn_envs x 16),
            # plus the scaled policy output as a residual.
            cur_preset_actions = batched_index_select(self.envs_replay_qtars, indices=self.progress_buf.unsqueeze(1), dim=1).squeeze(1)
            targets = cur_preset_actions + actions[..., :self.num_allegro_hand_dofs].clone().to(self.device) * self.delta_action_scale

        if self.delta_actions is not None:
            targets = targets + self.delta_actions

        if self.finetune_with_action_compensator:
            # The compensated targets are stored on ``self`` only; ``targets``
            # itself is left untouched here.
            delta_actions = self._inference_act_compensator(self.obs_buf, targets)
            if self.fingertip_only_action_compensator:
                self.compensated_targets = targets.clone()
                self.compensated_targets[..., [3, 7, 11, 15]] = self.compensated_targets[..., [3, 7, 11, 15]] + delta_actions * self.delta_action_scale
            else:
                self.compensated_targets = targets + delta_actions * self.delta_action_scale

        if self.openloop_replay:
            # Replay stored actions at the current step of each env (nn_envs x 16).
            targets = batched_index_select(self.openloop_replay_src_actions, indices=self.progress_buf.unsqueeze(1), dim=1).squeeze(1)

        if self.train_action_compensator_w_real_wm and self.use_bc_base_policy:
            targets = self.bc_policy_pred_targets

        self.cur_targets[:] = tensor_clamp(targets, self.allegro_hand_dof_lower_limits, self.allegro_hand_dof_upper_limits)

        if self.train_action_compensator_w_real_wm:
            cur_compensated_targets = self.cur_targets.clone()

            # The compensation is added on the compensator's output joints only,
            # on top of either the first-level compensated targets or the
            # clamped targets.
            joint_idxes = self.compensator_output_joint_idxes
            base_targets = self.first_level_compensated_targets if self.hierarchical_compensator else self.cur_targets
            cur_compensated_targets[..., joint_idxes] = base_targets[..., joint_idxes] + self.compensating_targets * self.delta_action_scale

            self.cur_compensated_targets = cur_compensated_targets

            # Each world-model helper is invoked twice: once with the
            # compensated targets and once with the uncompensated ones.
            if self.action_compensator_w_full_hand:
                self._get_full_hand_world_model_prediction(cur_compensated_targets)
                self._get_full_hand_world_model_prediction(self.cur_targets, use_compensated_targets=False)
            elif self.action_compensator_input_finger_idx == -1:
                if self.wm_per_joint_compensator_full_hand:
                    # No world-model call here: both predictions are set to the
                    # current joint positions.
                    self.real_wm_pred_next_state = self.allegro_hand_dof_pos.clone()
                    self.real_wm_pred_next_state_orijoints = self.allegro_hand_dof_pos.clone()
                else:
                    self._get_finger_world_model_prediction_perjoint(cur_compensated_targets)
                    self._get_finger_world_model_prediction_perjoint(self.cur_targets, use_compensated_targets=False)
            else:
                self._get_finger_world_model_prediction(cur_compensated_targets)
                self._get_finger_world_model_prediction(self.cur_targets, use_compensated_targets=False)

        if self.real_to_sim_auto_tune:
            # Broadcast the current step of the tested trajectory to every env.
            cur_preset_actions = self.auto_tune_actions[self.testing_traj_idx, self.testing_traj_ts, :].clone().unsqueeze(0).repeat(self.num_envs, 1).contiguous()
            self.cur_targets[:, :] = cur_preset_actions
            self.testing_traj_ts += 1

        self.prev_targets[:] = self.cur_targets.clone()
        self.object_rot_prev[:] = self.object_rot
        self.object_pos_prev[:] = self.object_pos

        def query_object_masses():
            # Mass of the first rigid body of the actor named 'object', per env.
            return to_torch(
                [self.gym.get_actor_rigid_body_properties(env, self.gym.find_actor_handle(env, 'object'))[0].mass for
                 env in self.envs], device=self.device)

        def gravity_force(obj_mass):
            # Force along -z, scaled by each object's mass and ``cur_gravity_force``.
            gravity_dir = torch.tensor([0, 0, -1], device=self.device).float()
            return gravity_dir.unsqueeze(0).unsqueeze(0) * obj_mass.unsqueeze(1).unsqueeze(-1) * self.cur_gravity_force

        if self.force_scale > 0.0:
            # Decay the forces from previous steps before sampling new ones.
            self.rb_forces *= torch.pow(self.force_decay, self.dt / self.force_decay_interval)
            obj_mass = query_object_masses()
            # Each env gets a fresh random force with probability ``prob``,
            # scaled by the object mass.
            prob = self.random_force_prob_scalar
            force_indices = (torch.less(torch.rand(self.num_envs, device=self.device), prob)).nonzero()
            self.rb_forces[force_indices, self.object_rb_handles, :] = torch.randn(
                self.rb_forces[force_indices, self.object_rb_handles, :].shape,
                device=self.device) * obj_mass[force_indices, None] * self.force_scale
            if self.sim_gravity_via_force:
                self.rb_forces[:, self.object_rb_handles, :] += gravity_force(obj_mass)

            self.gym.apply_rigid_body_force_tensors(self.sim, gymtorch.unwrap_tensor(self.rb_forces), None, gymapi.ENV_SPACE)

        elif self.sim_gravity_via_force:
            self.rb_forces[:, self.object_rb_handles, :] = gravity_force(query_object_masses())
            self.gym.apply_rigid_body_force_tensors(self.sim, gymtorch.unwrap_tensor(self.rb_forces), None, gymapi.ENV_SPACE)

        if self.apply_obj_virtual_force:
            # Signature: apply_rigid_body_force_tensors(sim, forceTensor=None, torqueTensor=None, space=ENV_SPACE)
            # NOTE(review): the buffer is passed in the *torque* slot here even
            # though it is named a force -- confirm this is intended.
            virtual_force_scale = 0.5
            self.rb_forces[:, self.object_rb_handles, :] = actions[..., self.num_allegro_hand_dofs: ].clone().to(self.device).unsqueeze(1) * virtual_force_scale
            self.gym.apply_rigid_body_force_tensors(self.sim, None, gymtorch.unwrap_tensor(self.rb_forces), gymapi.ENV_SPACE)
    

    def reset(self):
        """Reset via the parent class and return the observation dictionary.

        After the parent reset, the privileged-info and proprioceptive-history
        buffers are moved to ``self.rl_device`` and stored in ``self.obs_dict``
        under the ``'priv_info'`` and ``'proprio_hist'`` keys.
        """
        super().reset()
        extra_observations = (
            ('priv_info', self.priv_info_buf),
            ('proprio_hist', self.proprio_hist_buf),
        )
        for obs_key, obs_buffer in extra_observations:
            self.obs_dict[obs_key] = obs_buffer.to(self.rl_device)
        return self.obs_dict

    def step(self, actions):
        """Advance the task by one step and return ``(obs_dict, rew, reset, extras)``.

        The parent ``step`` is run first; afterwards the privileged-info and
        proprioceptive-history buffers are attached to ``self.obs_dict`` on
        ``self.rl_device``.

        When ``self.train_action_compensator_w_real_wm`` is set, an additional
        ``'compensator_obs'`` entry is built from the last
        ``self.compensator_history_length`` entries of the lagged joint-position
        and joint-target histories: uniform noise in
        ``[-joint_noise_scale, joint_noise_scale]`` is added to the positions,
        which are then passed through ``unscale`` with the hand DOF limits,
        concatenated with the targets along the last dimension and flattened
        per environment.
        """
        super().step(actions)

        rl_device = self.rl_device
        self.obs_dict['priv_info'] = self.priv_info_buf.to(rl_device)
        self.obs_dict['proprio_hist'] = self.proprio_hist_buf.to(rl_device)

        if self.train_action_compensator_w_real_wm:
            hist_len = self.compensator_history_length
            qpos_hist = self.obs_buf_lag_history_qpos[:, -hist_len:].clone()

            # Symmetric uniform noise in [-1, 1], scaled by the joint noise level.
            unit_noise = torch.rand(qpos_hist.shape).to(self.device) * 2.0 - 1.0
            noisy_qpos = unit_noise * self.joint_noise_scale + qpos_hist

            qtars_hist = self.obs_buf_lag_history_qtars[:, -hist_len:].clone()

            noisy_qpos = unscale(
                noisy_qpos,
                self.allegro_hand_dof_lower_limits,
                self.allegro_hand_dof_upper_limits,
            )

            if not self.use_masked_action_compensator:
                # Unmasked compensator only sees the selected joint columns.
                joint_cols = self.compensator_input_joint_idxes
                compensator_in = torch.cat(
                    [noisy_qpos[..., joint_cols], qtars_hist[..., joint_cols]], dim=-1
                )
            else:
                compensator_in = torch.cat([noisy_qpos, qtars_hist], dim=-1)

            # One flat feature vector per environment.
            n_rows = compensator_in.shape[0]
            self.obs_dict['compensator_obs'] = (
                compensator_in.contiguous().view(n_rows, -1).contiguous()
            )

        return self.obs_dict, self.rew_buf, self.reset_buf, self.extras

    
    def update_low_level_control(self):
        """Refresh simulator state and push the low-level joint command.

        The tracked target is ``self.compensated_targets`` when
        ``self.finetune_with_action_compensator`` is set, otherwise
        ``self.cur_targets``.

        * Position control: the target is sent directly as the DOF position target.
        * Torque control: a PD torque is computed from the target, the current
          joint position and a finite-difference velocity (position change across
          the refresh divided by ``self.dt``), clipped to ``[-0.5, 0.5]`` and sent
          as the DOF actuation force.
        """
        pos_before_refresh = self.allegro_hand_dof_pos.clone()
        self._refresh_gym()

        if self.finetune_with_action_compensator:
            tracked_targets = self.compensated_targets
        else:
            tracked_targets = self.cur_targets

        if not self.torque_control:
            self.gym.set_dof_position_target_tensor(self.sim, gymtorch.unwrap_tensor(tracked_targets))
            return

        pos_now = self.allegro_hand_dof_pos
        vel_estimate = (pos_now - pos_before_refresh) / self.dt
        self.dof_vel_finite_diff = vel_estimate.clone()

        pd_torques = self.p_gain * (tracked_targets - pos_now) - self.d_gain * vel_estimate
        self.torques = torch.clip(pd_torques, -0.5, 0.5).clone()
        self.gym.set_dof_actuation_force_tensor(self.sim, gymtorch.unwrap_tensor(self.torques))

    def check_termination(self, object_pos):
        
        if self.recovery_training:
            diff_cur_obj_pos_w_target = torch.norm(object_pos - self.target_root_state_tensor[self.object_indices, :3], p=2, dim=-1)
            resets = torch.logical_or(
                torch.less(diff_cur_obj_pos_w_target, self.recovery_succ_obj_pos_thres),
                torch.greater_equal(self.progress_buf, self.max_episode_length),
            )
            resets = torch.logical_or(resets, 
                                      torch.less(object_pos[:, -1], self.reset_z_threshold),
                                      )
            return resets
        
        if self.real_to_sim_auto_tune:
            if self.testing_traj_ts >= self.auto_tune_states.size(1):
                resets = torch.ones((self.num_envs, ), dtype=torch.float32).to(self.device).bool()
                self.testing_traj_idx += 1 # increase the testing traj idx #
                print(f"Testing {self.testing_traj_idx}-th trajectory")
            else:
                resets = torch.zeros((self.num_envs, ), dtype=torch.float32).to(self.device).bool()
            
            return resets
        
        if self.openloop_replay or self.train_action_compensator_w_real_wm:
            resets = torch.greater_equal(self.progress_buf, self.max_episode_length)
            return resets
        
        
        if self.train_action_compensator_w_real_wm and self.train_action_compensator_free_hand:
            resets = torch.greater_equal(self.progress_buf, self.max_episode_length)
            return resets
        
        if  self.train_action_compensator:
            if self.action_compensator_w_obj:
                resets = torch.logical_or(
                    torch.less(object_pos[:, -1], self.reset_z_threshold),
                    torch.greater_equal(self.progress_buf, self.max_episode_length),
                )
            else:
                resets = torch.greater_equal(self.progress_buf, self.max_episode_length - 4)
            return resets
        
        if self.omni_wrist_ornt:
            init_obj_pos = self.init_root_state_tensor[self.object_indices, :3]
            diff_curr_obj_pos_w_init = torch.norm(object_pos - init_obj_pos, p=2, dim=-1)
            resets = torch.logical_or(
                torch.greater(diff_curr_obj_pos_w_init, self.reset_pos_dist_val),
                torch.greater_equal(self.progress_buf, self.max_episode_length),
            )
        else:
            if self.hand_tracking and self.hand_tracking_nobj:
                resets = torch.greater_equal(self.progress_buf, self.max_episode_length)
            else:
                resets = torch.logical_or(
                    torch.less(object_pos[:, -1], self.reset_z_threshold),
                    torch.greater_equal(self.progress_buf, self.max_episode_length),
                )
        

        if self.add_tilding_termination:
            keypoint_A = self.keypoints[0].unsqueeze(0).repeat(self.num_envs, 1)
            keypoint_B = self.keypoints[1].unsqueeze(0).repeat(self.num_envs, 1)
            keypoint_A_z = quat_apply(self.object_rot, keypoint_A)[..., 2]
            keypoint_B_z = quat_apply(self.object_rot, keypoint_B)[..., 2]
            diff_z = torch.abs(keypoint_A_z - keypoint_B_z)
            dist_A_B = torch.norm(keypoint_A - keypoint_B, p=2, dim=-1)
            threshould_A_B = dist_A_B * math.sin(self.tilding_degree_threshold / float(180) * np.pi)
            # if dist_z > threshould_A_B: #
            resets = torch.logical_or(resets, torch.greater_equal(diff_z, threshould_A_B))
            
            
        return resets
    
    def check_obj_falling_reset(self, object_pos):
        if self.omni_wrist_ornt:
            init_obj_pos = self.init_root_state_tensor[self.object_indices, :3]
            diff_curr_obj_pos_w_init = torch.norm(object_pos - init_obj_pos, p=2, dim=-1)
            obj_falling_resets = torch.logical_and(
                torch.greater(diff_curr_obj_pos_w_init, self.reset_pos_dist_val),
                torch.less(self.progress_buf, self.max_episode_length),
            )
        else:
            obj_falling_resets = torch.logical_and(
                torch.less(object_pos[:, -1], self.reset_z_threshold),
                torch.less(self.progress_buf, self.max_episode_length),
            )
        return obj_falling_resets
    
    def _refresh_gym(self):
        """Refresh the simulator state tensors and re-read the cached object state.

        Refreshes DOF state, actor root state, rigid-body state and net contact
        forces (plus force-sensor and DOF-force tensors when
        ``self.add_force_obs`` is set), then stores the object rows of
        ``self.root_state_tensor`` as pose (columns 0:7), position (0:3),
        rotation (3:7), linear velocity (7:10) and angular velocity (10:13).
        """
        sim = self.sim
        always_refreshed = (
            self.gym.refresh_dof_state_tensor,
            self.gym.refresh_actor_root_state_tensor,
            self.gym.refresh_rigid_body_state_tensor,
            self.gym.refresh_net_contact_force_tensor,
        )
        for refresh_fn in always_refreshed:
            refresh_fn(sim)

        if self.add_force_obs:
            self.gym.refresh_force_sensor_tensor(sim)
            self.gym.refresh_dof_force_tensor(sim)

        # Each field is indexed separately so the resulting tensors stay
        # independent of one another.
        object_state_columns = (
            ('object_pose', slice(0, 7)),
            ('object_pos', slice(0, 3)),
            ('object_rot', slice(3, 7)),
            ('object_linvel', slice(7, 10)),
            ('object_angvel', slice(10, 13)),
        )
        for attr_name, columns in object_state_columns:
            setattr(self, attr_name, self.root_state_tensor[self.object_indices, columns])
        
        # print(f"object_pose: {self.object_pose}")

    def _setup_domain_rand_config(self, rand_config):
        self.randomize_mass = rand_config['randomizeMass']
        self.randomize_mass_lower = rand_config['randomizeMassLower']
        self.randomize_mass_upper = rand_config['randomizeMassUpper']
        self.randomize_com = rand_config['randomizeCOM']
        self.randomize_com_lower = rand_config['randomizeCOMLower']
        self.randomize_com_upper = rand_config['randomizeCOMUpper']
        self.randomize_friction = rand_config['randomizeFriction']
        self.randomize_friction_lower = rand_config['randomizeFrictionLower']
        self.randomize_friction_upper = rand_config['randomizeFrictionUpper']
        self.randomize_scale = rand_config['randomizeScale']
        self.scale_list_init = rand_config['scaleListInit']
        self.randomize_scale_list = rand_config['randomizeScaleList']
        self.randomize_scale_lower = rand_config['randomizeScaleLower']
        self.randomize_scale_upper = rand_config['randomizeScaleUpper']
        self.randomize_pd_gains = rand_config['randomizePDGains']
        self.randomize_p_gain_lower = rand_config['randomizePGainLower']
        self.randomize_p_gain_upper = rand_config['randomizePGainUpper']
        self.randomize_d_gain_lower = rand_config['randomizeDGainLower']
        self.randomize_d_gain_upper = rand_config['randomizeDGainUpper']
        self.joint_noise_scale = rand_config['jointNoiseScale']

    def _setup_priv_option_config(self, p_config):
        self.enable_priv_obj_position = p_config['enableObjPos']
        self.enable_priv_obj_mass = p_config['enableObjMass']
        self.enable_priv_obj_scale = p_config['enableObjScale']
        self.enable_priv_obj_com = p_config['enableObjCOM']
        self.enable_priv_obj_friction = p_config['enableObjFriction']

    def _update_priv_buf(self, env_id, name, value, lower=None, upper=None):
        # normalize to -1, 1
        s, e = self.priv_info_dict[name]
        if eval(f'self.enable_priv_{name}'):
            if type(value) is list:
                value = to_torch(value, dtype=torch.float, device=self.device)
            if type(lower) is list or upper is list:
                lower = to_torch(lower, dtype=torch.float, device=self.device)
                upper = to_torch(upper, dtype=torch.float, device=self.device)
            if lower is not None and upper is not None:
                value = (2.0 * value - upper - lower) / (upper - lower)
            self.priv_info_buf[env_id, s:e] = value
        else:
            self.priv_info_buf[env_id, s:e] = 0

    def _setup_object_info(self, o_config):
        
        self.use_multi_objs = self.config['env'].get('useMultiObjs', False)
        
        self.object_type = o_config['type']
        raw_prob = o_config['sampleProb']
        
        specified_obj_idx = o_config['specifiedObjectIdx']
        specified_obj_idx = str(specified_obj_idx)
        
        if len(specified_obj_idx) == 0:
            if "_" in self.object_type:
                main_type, subset_type = self.object_type.split("_")
                cuboids = sorted(glob(f'assets/{main_type}/{subset_type}/*.urdf'))
            else: # no subset type #
                cuboids = sorted(glob(f'assets/{self.object_type}/*.urdf'))
                
            if self.object_type == 'dexenv':
                dexenv_forbid_idxes = [46, 49, 52, 66, 68, 56, 10, 102, 117]
                specified_obj_idx = [ _ for _ in range(len(cuboids)) if _ not in dexenv_forbid_idxes ]
                
                # ### keeping obj names ### #
                # specified_obj_idx = [20, 21, 22, 25, 40, 50, 51, 53, 65, 67, 70, 33, 58, 38, 31, 60, 61, 62, 75, 77, 78, 11, 12, 27, 29, 35, 36, 37, 55, 88, 89, 94, 96, 105, 110, 113, 115, 118]
                
            elif self.object_type == 'grab':
                # specified_obj_idx = [0, 4, 5, 6, 9, 10, 11, 12, 13, 14, 15, 18, 19, 22, 24, 29, 30, 31, 33, 34, 40, 42, 46, 47, 48, 49]
                specified_obj_idx = [2, 10, 15, 16, 21, 28, 39, 47]
            else:
                specified_obj_idx = [ _ for _ in range(len(cuboids)) ]
            
            specified_obj_idx = [str(_) for _ in specified_obj_idx]
            specified_obj_idx = "AND".join(specified_obj_idx)
            self.config['env']['object']['specifiedObjectIdx'] = specified_obj_idx
            
        
        if self.use_multi_objs: # and multiple objs #
            self.multi_objs_specified_obj_idx = str(specified_obj_idx).split('ANDOBJ')
            self.multi_objs_specified_obj_idx = [
                cur_obj_specified_obj_idx.split('AND') for cur_obj_specified_obj_idx in self.multi_objs_specified_obj_idx
            ]
            self.multi_objs_specified_obj_idx = [
                [int(i) for i in cur_obj_specified_obj_idx] for cur_obj_specified_obj_idx in self.multi_objs_specified_obj_idx
            ]
            tot_nn_insts = sum( [ len(cur_obj_idx_list) for cur_obj_idx_list in self.multi_objs_specified_obj_idx ] )
            # self.specified_obj_idx = [ _ for _ in range(tot_nn_insts) ]
            self.obj_idx_to_obj_type = {}
            per_obj_type_nn_insts = [ len(cur_obj_idx_list ) for cur_obj_idx_list in self.multi_objs_specified_obj_idx ]
            
            self.specified_obj_idx = []
            
            for i_obj_type, cur_obj_inst_list in enumerate(self.multi_objs_specified_obj_idx):
                cur_accumulated_nn_objs_insts = sum(per_obj_type_nn_insts[: i_obj_type])
                for cur_obj_inst_idx in cur_obj_inst_list:
                    self.obj_idx_to_obj_type[cur_accumulated_nn_objs_insts + cur_obj_inst_idx] = i_obj_type
                # for i_obj_inst_idx, cur_obj_inst_idx in enumerate(cur_obj_inst_list):
                #     self.obj_idx_to_obj_type[cur_accumulated_nn_objs_insts + i_obj_inst_idx] = i_obj_type
                self.specified_obj_idx.append(cur_accumulated_nn_objs_insts + cur_obj_inst_idx)
            print(f"self.obj_idx_to_obj_type: {self.obj_idx_to_obj_type}")
        else:
            specified_obj_idx = str(specified_obj_idx)
            
            if len(specified_obj_idx) == 0:
                specified_obj_idx = []
            else:
                specified_obj_idx = specified_obj_idx.split('AND')
                specified_obj_idx = [int(i) for i in specified_obj_idx]
        
            self.specified_obj_idx = specified_obj_idx
            self.obj_idx_to_obj_type = {i: 0 for i in specified_obj_idx}
            
            
        assert (sum(raw_prob) == 1)

        if self.use_multi_objs:
            primitive_list = self.object_type.split('ANDOBJ')
            raw_prob = [ 1.0 / len(primitive_list) for _ in primitive_list ]
        else:
            primitive_list = self.object_type.split('+')
        print('---- Primitive List ----')
        print(primitive_list)
        self.object_type_prob = []
        self.object_type_list = []
        self.asset_files_dict = {
            'simple_tennis_ball': 'assets/ball.urdf',
        }
        self.asset_feature_dict = {}
        self.pts_assets_dict = {}
        # change the primitive ist #
        for p_id, prim in enumerate(primitive_list):
            
            if self.use_multi_objs:
                specified_obj_idx = self.multi_objs_specified_obj_idx[p_id]
            
            if 'cuboid' in prim:
                if self.use_multi_objs:
                    subset_name = prim.split('_')[-1]
                else:
                    subset_name = self.object_type.split('_')[-1]
                cuboids = sorted(glob(f'assets/cuboid/{subset_name}/*.urdf'))
                
                if len(specified_obj_idx) > 0:
                    cuboids = [cuboids[per_specified_obj_idx] for per_specified_obj_idx in specified_obj_idx]
                    print(f"cuboids: {cuboids}")
                
                
                cuboid_list = [f'cuboid_{i}' for i in range(len(cuboids))]
                # print(f"subset_name: {subset_name}, debugging cuboids objs list: {cuboid_list}")
                self.object_type_list += cuboid_list
                for i, name in enumerate(cuboids):
                    self.asset_files_dict[f'cuboid_{i}'] = name.replace('../assets/', '')
                self.object_type_prob += [raw_prob[p_id] / len(cuboid_list) for _ in cuboid_list]
                
                # (0, 0.1, 0), (0, -0.1, 0)
                self.keypoints = torch.tensor([
                    [0, 0.1, 0], [0, -0.1, 0],
                ], dtype=torch.float32,  ).cuda()
            
            elif 'cylinder' in prim:
                if self.use_multi_objs:
                    subset_name = prim.split('_')[-1]
                else:
                    subset_name = self.object_type.split('_')[-1]
                cylinders = sorted(glob(f'assets/cylinder/{subset_name}/*.urdf')) # add mesh file in the assets folder and load their mesh files --- load the mesh vertices # TODO: load asset_obj_file_dict --- cylinder's name to object points #
                
                if len(specified_obj_idx) > 0:
                    cylinders = [cylinders[per_specified_obj_idx] for per_specified_obj_idx in specified_obj_idx]
                
                cylinder_list = [f'cylinder_{i}' for i in range(len(cylinders))]
                self.object_type_list += cylinder_list
                for i, name in enumerate(cylinders):
                    self.asset_files_dict[f'cylinder_{i}'] = name.replace('../assets/', '')
                    
                if self.add_zrot_penalty:
                    for i, name in enumerate(cylinders):
                        pts_fn = name.replace(".urdf", "_sampled_points.npy")
                        pts = np.load(pts_fn) # load points -- nn_points x 3 #
                        self.pts_assets_dict[f'cylinder_{i}'] = pts
                    
                self.object_type_prob += [raw_prob[p_id] / len(cylinder_list) for _ in cylinder_list]
                
                self.keypoints = torch.tensor([
                    [0, 0, 0.1], [0, 0, -0.1],
                ], dtype=torch.float32,  ).cuda()
            
            elif 'grab' in prim:
                grabs = sorted(glob(f'assets/grab/*.urdf'))
                if len(specified_obj_idx) > 0:
                    grabs = [grabs[per_specified_obj_idx] for per_specified_obj_idx in specified_obj_idx]
                grab_list = [f'grab_{i}' for i in range(len(grabs))]
                self.object_type_list += grab_list
                for i, name in enumerate(grabs):
                    self.asset_files_dict[f'grab_{i}'] = name.replace('../assets/', '')
                self.object_type_prob += [raw_prob[p_id] / len(grab_list) for _ in grab_list]
                
            elif 'dexenv' in prim:
                mesh_nm_to_feature = np.load("hora/pc_processing/mesh_nm_to_feature.npy", allow_pickle=True).item()
                
                
                grabs = sorted(glob(f'assets/dexenv/*.urdf'))
                if len(specified_obj_idx) > 0:
                    grabs = [grabs[per_specified_obj_idx] for per_specified_obj_idx in specified_obj_idx]
                
                
                grab_list = [f'dexenv_{i}' for i in range(len(grabs))]
                self.object_type_list += grab_list
                for i, name in enumerate(grabs):
                    self.asset_files_dict[f'dexenv_{i}'] = name.replace('../assets/', '')
                    
                    obj_nm = name.split('/')[-1].split('.')[0]
                    obj_feature = mesh_nm_to_feature[obj_nm]
                    self.asset_feature_dict[f'dexenv_{i}'] = torch.from_numpy(obj_feature).float()
                    
                self.object_type_prob += [raw_prob[p_id] / len(grab_list) for _ in grab_list]
                
                # [46, 49, 52, 66, 68, 56, 10, 102, 117] #
                
            else:
                self.object_type_list += [prim]
                self.object_type_prob += [raw_prob[p_id]]
        print('---- Object List ----')
        print(self.object_type_list)
        assert (len(self.object_type_list) == len(self.object_type_prob))

    def _allocate_task_buffer(self, num_envs): # allocate task buffers # # 
        # extra buffers for observe randomized params #
        self.prop_hist_len = self.config['env']['hora']['propHistoryLen']
        self.num_env_factors = self.config['env']['hora']['privInfoDim']
        self.priv_info_buf = torch.zeros((num_envs, self.num_env_factors), device=self.device, dtype=torch.float)
        self.proprio_hist_buf = torch.zeros((num_envs, self.prop_hist_len, 32), device=self.device, dtype=torch.float)

    def _setup_reward_config(self, r_config):
        self.angvel_clip_min = r_config['angvelClipMin']
        self.angvel_clip_max = r_config['angvelClipMax']
        self.rotate_reward_scale = r_config['rotateRewardScale']
        self.object_linvel_penalty_scale = r_config['objLinvelPenaltyScale']
        self.pose_diff_penalty_scale = r_config['poseDiffPenaltyScale']
        self.torque_penalty_scale = r_config['torquePenaltyScale']
        self.work_penalty_scale = r_config['workPenaltyScale']
        ## NOTE: add the aux pose guidacne configs ##
        self.add_aux_pose_guidance = r_config['addAuxPoseGuidance']
        self.guiding_delta_rot_angle = r_config['guidingDeltaRotAngle']
        self.rot_angle_threshold = r_config['rotAngleThreshold']
        self.guiding_delta_rot_radian = float(self.guiding_delta_rot_angle) / float(180) * np.pi
        self.rot_radian_threshold = float(self.rot_angle_threshold) / float(180) * np.pi
        self.aux_pose_guidance_coef = r_config['auxPoseGuidanceCoef']
        self.add_tilding_termination = r_config['addTildingTermination']
        self.tilding_degree_threshold = r_config['tildingDegreeThreshold']
        self.add_zrot_penalty = r_config['addZRotPenalty']
        self.add_rotp = r_config['addRotp']
        self.rotp_coef = r_config['rotpCoef']
        self.cur_rotp_coef = 0.0
        self.rotp_step = 0
        self.rotp_warmup_steps = r_config['rotpWarmupSteps']
        self.rotp_increasing_steps = r_config['rotpIncreasingSteps']
    

    def _create_object_asset(self):
        """Load the hand asset and one asset per entry of ``self.object_type_list``.

        Sets ``self.hand_asset``, ``self.fingertip_handles``,
        ``self.hand_joint_names``, ``self.object_asset_list``,
        ``self.num_object_dofs`` (DOF count of the last loaded object) and
        ``self.object_points_list`` (filled only when ``self.add_zrot_penalty``
        is set).

        Raises:
            ValueError: if ``self.hand_type`` is not one of ``'leap'``,
                ``'allegro_public'`` or ``'allegro_internal'``.
        """
        # Asset paths are resolved relative to two directories above this file.
        asset_root = os.path.join(os.path.dirname(os.path.abspath(__file__)), '../../')

        hand_urdf_by_type = {
            'leap': 'assets/leap_hand/leap_hand_right.urdf',
            'allegro_public': 'assets/allegro/allegro.urdf',
            'allegro_internal': 'assets/allegro/allegro_internal.urdf',
        }
        if self.hand_type not in hand_urdf_by_type:
            raise ValueError(f"Unknown hand type: {self.hand_type}")
        hand_asset_file = hand_urdf_by_type[self.hand_type]

        # ---- hand asset ----
        hand_opts = gymapi.AssetOptions()
        hand_opts.flip_visual_attachments = False
        hand_opts.fix_base_link = True
        hand_opts.collapse_fixed_joints = True if self.save_init_pose else False
        hand_opts.disable_gravity = True
        hand_opts.thickness = 0.001
        hand_opts.angular_damping = 0.01
        # Torque control drives the joints with efforts, otherwise with position targets.
        hand_opts.default_dof_drive_mode = (
            gymapi.DOF_MODE_EFFORT if self.torque_control else gymapi.DOF_MODE_POS
        )
        self.hand_asset = self.gym.load_asset(self.sim, asset_root, hand_asset_file, hand_opts)

        self.fingertip_handles = [
            self.gym.find_asset_rigid_body_index(self.hand_asset, tip_name)
            for tip_name in self.fingertips
        ]

        joint_count = self.gym.get_asset_joint_count(self.hand_asset)
        self.hand_joint_names = [
            self.gym.get_asset_joint_name(self.hand_asset, joint_idx)
            for joint_idx in range(joint_count)
        ]
        print(f"self.hand_joint_names: {self.hand_joint_names}")

        print(f"fingertip_handles: {self.fingertip_handles}")

        # ---- object assets ----
        self.object_asset_list = []
        self.object_points_list = []
        for obj_name in self.object_type_list:
            obj_urdf = self.asset_files_dict[obj_name]
            obj_opts = gymapi.AssetOptions()

            if self.disable_obj_gravity:
                obj_opts.disable_gravity = True

            if self.enable_vhacd:
                # Convex decomposition of the object's collision mesh.
                obj_opts.vhacd_enabled = True
                obj_opts.vhacd_params = gymapi.VhacdParams()
                obj_opts.vhacd_params.resolution = 3000000

            loaded_asset = self.gym.load_asset(self.sim, asset_root, obj_urdf, obj_opts)
            self.num_object_dofs = self.gym.get_asset_dof_count(loaded_asset)
            self.object_asset_list.append(loaded_asset)

            if self.add_zrot_penalty:
                # Pre-loaded sampled points of this object.
                self.object_points_list.append(self.pts_assets_dict[obj_name])
            
        
        # if self.grasp_to_grasp:
        #     self.target_object_asset_list = []
        #     for object_type in self.object_type_list:
        #         target_object_asset_file = self.asset_files_dict[object_type]
        #         target_object_asset_options = gymapi.AssetOptions()
                
        #         target_object_asset_options.disable_gravity = True
                
        #         target_object_asset = self.gym.load_asset(self.sim, asset_root, target_object_asset_file, target_object_asset_options)
        #         self.target_object_asset_list.append(target_object_asset)
            

    def _init_object_pose(self):
        """Build the initial transforms of the hand and of the object.

        The hand is placed at ``(0, 0, 0.5)``. The object is offset from the
        hand in x/y by a per-hand-type amount, and its height is chosen as
        follows:

        * Allegro hands: ``0.66`` when ``self.save_init_pose`` else ``0.65``.
        * LEAP hand: ``self.object_downfacing_init_z``.
        * ``0.02`` is subtracted unless ``'internal'`` appears in
          ``self.grasp_cache_name``.
        * If ``self.hand_facing_dir == 'down'`` and the grasp cache is not
          ``'leap_change_g_dir'``, the hand rotation is replaced by its
          down-facing variant and the height is set to
          ``self.object_downfacing_init_z`` (overriding the steps above).

        Returns:
            Tuple ``(hand_start_pose, object_start_pose)`` of
            ``gymapi.Transform`` objects.

        Raises:
            ValueError: if ``self.hand_type`` is not one of ``'leap'``,
                ``'allegro_public'`` or ``'allegro_internal'``.
        """
        allegro_hand_start_pose = gymapi.Transform()
        allegro_hand_start_pose.p = gymapi.Vec3(0, 0, 0.5)
        allegro_hand_start_pose.r = gymapi.Quat.from_axis_angle(
            gymapi.Vec3(0, 1, 0), -np.pi / 2) * gymapi.Quat.from_axis_angle(gymapi.Vec3(1, 0, 0), np.pi / 2)

        # Per-hand planar offset of the object and default object height.
        # The object height is always set explicitly below, so no per-hand
        # z offset is needed here.
        if self.hand_type in ['allegro_public', 'allegro_internal']:
            pose_dx, pose_dy = -0.01, 0.0
            object_z = 0.66 if self.save_init_pose else 0.65
        elif self.hand_type == 'leap':
            pose_dx, pose_dy = -0.01, -0.07
            object_z = self.object_downfacing_init_z
        else:
            raise ValueError(f"Unknown hand type: {self.hand_type}")

        if 'internal' not in self.grasp_cache_name:
            object_z -= 0.02

        if self.hand_facing_dir == 'down' and (self.grasp_cache_name not in ['leap_change_g_dir']):
            # Down-facing hand: flip the rotation about y and use the
            # configured down-facing object height.
            allegro_hand_start_pose.r = gymapi.Quat.from_axis_angle(
                gymapi.Vec3(0, 1, 0), np.pi / 2) * gymapi.Quat.from_axis_angle(gymapi.Vec3(1, 0, 0), np.pi / 2)
            object_z = self.object_downfacing_init_z

        object_start_pose = gymapi.Transform()
        object_start_pose.p = gymapi.Vec3()
        object_start_pose.p.x = allegro_hand_start_pose.p.x + pose_dx
        object_start_pose.p.y = allegro_hand_start_pose.p.y + pose_dy
        object_start_pose.p.z = object_z
        return allegro_hand_start_pose, object_start_pose


def compute_hand_reward(
    object_linvel_penalty, object_linvel_penalty_scale: float,
    rotate_reward, rotate_reward_scale: float,
    pose_diff_penalty, pose_diff_penalty_scale: float,
    torque_penalty, torque_pscale: float,
    work_penalty, work_pscale: float,
):
    """Combine the rotation reward and the penalty terms into a single reward.

    The result is ``rotate_reward_scale * rotate_reward`` plus each penalty
    multiplied by its own scale, accumulated in the order: object linear
    velocity, pose difference, torque, work.
    """
    scaled_penalties = (
        (object_linvel_penalty, object_linvel_penalty_scale),
        (pose_diff_penalty, pose_diff_penalty_scale),
        (torque_penalty, torque_pscale),
        (work_penalty, work_pscale),
    )
    total = rotate_reward_scale * rotate_reward
    for penalty_term, penalty_weight in scaled_penalties:
        total = total + penalty_term * penalty_weight
    return total


def quat_to_axis_angle(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Turn quaternions (real part last) into axis-angle rotation vectors.

    Adapted from PyTorch3D's ``quaternion_to_axis_angle``; the only change is
    the component layout, which here is (x, y, z, w) instead of (w, x, y, z):
    https://pytorch3d.readthedocs.io/en/latest/_modules/pytorch3d/transforms/rotation_conversions.html#quaternion_to_axis_angle

    Args:
        quaternions: tensor of shape (..., 4) holding quaternions with the
            real (scalar) component in the last slot.

    Returns:
        Tensor of shape (..., 3). Each vector points along the rotation axis
        and its magnitude is the anticlockwise rotation angle in radians.
    """
    vector_part = quaternions[..., :3]
    scalar_part = quaternions[..., 3:]

    # |xyz| = sin(angle / 2) and w = cos(angle / 2) for a unit quaternion.
    half_angle = torch.atan2(
        vector_part.norm(p=2, dim=-1, keepdim=True), scalar_part
    )
    angle = half_angle * 2

    near_zero = angle.abs() < 1e-6
    regular = ~near_zero

    # ratio = sin(angle / 2) / angle, so that xyz / ratio = axis * angle.
    ratio = torch.empty_like(angle)
    ratio[regular] = torch.sin(half_angle[regular]) / angle[regular]
    # Near zero the exact quotient is 0/0; use the Taylor expansion
    # sin(x/2)/x ~= 1/2 - x^2/48 instead.
    ratio[near_zero] = 0.5 - (angle[near_zero] * angle[near_zero]) / 48

    return vector_part / ratio

@torch.jit.script
def randomize_rotation_rpy(rand0, rand1, rand2, x_unit_tensor, y_unit_tensor, z_unit_tensor):
    """Compose a quaternion from three scaled rotations about the given axes.

    Each ``rand*`` input is multiplied by pi to obtain an angle, which is then
    turned into a quaternion about the matching ``*_unit_tensor`` axis. The
    three quaternions are multiplied in the order (x * y) * z.

    NOTE(review): presumably ``rand*`` are samples in [-1, 1] (giving angles
    in [-pi, pi]) and the unit tensors are per-env basis vectors; confirm
    against the caller.
    """
    rot_about_x = quat_from_angle_axis(rand0 * np.pi, x_unit_tensor)
    rot_about_y = quat_from_angle_axis(rand1 * np.pi, y_unit_tensor)
    rot_about_z = quat_from_angle_axis(rand2 * np.pi, z_unit_tensor)
    rot_xy = quat_mul(rot_about_x, rot_about_y)
    return quat_mul(rot_xy, rot_about_z)

@torch.jit.script
def quaternion_rotation_vector(q, v):
    """Rotate vectors ``v`` by quaternions ``q`` (components ordered x, y, z, w).

    Implements, component by component,

        t  = 2 * cross(q_xyz, v)
        v' = v + q_w * t + cross(q_xyz, t)

    which equals the rotation of ``v`` by ``q`` when ``q`` has unit norm; the
    quaternion is not normalised here.

    Args:
        q: tensor whose last axis holds (x, y, z, w).
        v: tensor whose last axis holds (x, y, z).

    Returns:
        Tensor of rotated vectors with the three components stacked on the
        last axis.
    """
    q_x = q[..., 0]
    q_y = q[..., 1]
    q_z = q[..., 2]
    q_w = q[..., 3]
    v_x = v[..., 0]
    v_y = v[..., 1]
    v_z = v[..., 2]

    # t = 2 * (q_xyz x v)
    cross_x = 2 * (q_y * v_z - q_z * v_y)
    cross_y = 2 * (q_z * v_x - q_x * v_z)
    cross_z = 2 * (q_x * v_y - q_y * v_x)

    # v' = v + w * t + (q_xyz x t)
    rotated_x = v_x + q_w * cross_x + q_y * cross_z - q_z * cross_y
    rotated_y = v_y + q_w * cross_y + q_z * cross_x - q_x * cross_z
    rotated_z = v_z + q_w * cross_z + q_x * cross_y - q_y * cross_x

    return torch.stack([rotated_x, rotated_y, rotated_z], dim=-1)